diff --git a/.gitignore b/.gitignore index 7328b3294..1ed76375d 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,9 @@ tools/env.sh tools/openfst-1.8.1/ tools/libsndfile/ tools/python-soundfile/ +tools/onnx +tools/onnxruntime +tools/Paddle2ONNX speechx/fc_patch/ diff --git a/.mergify.yml b/.mergify.yml index 68b248101..5cb1f4865 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -52,7 +52,7 @@ pull_request_rules: add: ["T2S"] - name: "auto add label=Audio" conditions: - - files~=^paddleaudio/ + - files~=^paddlespeech/audio/ actions: label: add: ["Audio"] @@ -100,7 +100,7 @@ pull_request_rules: add: ["README"] - name: "auto add label=Documentation" conditions: - - files~=^(docs/|CHANGELOG.md|paddleaudio/CHANGELOG.md) + - files~=^(docs/|CHANGELOG.md) actions: label: add: ["Documentation"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3cc36e00..6e7ae1fbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,12 +51,12 @@ repos: language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - - id: copyright_checker - name: copyright_checker - entry: python .pre-commit-hooks/copyright-check.hook - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ - exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ + #- id: copyright_checker + # name: copyright_checker + # entry: python .pre-commit-hooks/copyright-check.hook + # language: system + # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ + # exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: diff --git a/.pre-commit-hooks/copyright-check.hook b/.pre-commit-hooks/copyright-check.hook index 26044c29e..761edbc01 
100644 --- a/.pre-commit-hooks/copyright-check.hook +++ b/.pre-commit-hooks/copyright-check.hook @@ -19,7 +19,7 @@ import subprocess import platform COPYRIGHT = ''' -Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 2ade8a69c..c9d4796c8 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,4 @@ ([简体中文](./README_cn.md)|English) - - -

@@ -20,20 +17,19 @@

-

- | Quick Start +

+ Quick Start | Quick Start Server | Quick Start Streaming Server - | -
| Documents | Models List - | -

+ | AIStudio Courses + | Paper + | Gitee +
- - +------------------------------------------------------------------------------------ **PaddleSpeech** is an open-source toolkit on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for a variety of critical tasks in speech and audio, with the state-of-art and influential models. @@ -170,23 +166,12 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision - 🤗 2021.12.14: [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available! - 👏🏻 2021.12.10: `CLI` is available for `Audio Classification`, `Automatic Speech Recognition`, `Speech Translation (English to Chinese)` and `Text-to-Speech`. -### 🔥 Hot Activities - - - -- 2021.12.21~12.24 - - 4 Days Live Courses: Depth interpretation of PaddleSpeech! - - **Courses videos and related materials: https://aistudio.baidu.com/aistudio/education/group/info/25130** ### Community -- Scan the QR code below with your Wechat (reply【语音】after your friend's application is approved), you can access to official technical exchange group. Look forward to your participation. +- Scan the QR code below with your Wechat, you can access to official technical exchange group and get the bonus ( more than 20GB learning materials, such as papers, codes and videos ) and the live link of the lessons. Look forward to your participation.
- +
## Installation diff --git a/README_cn.md b/README_cn.md index f5ba93629..66ba3c0ec 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,40 +18,21 @@

-

- Quick Start - | Quick Start Server - | Quick Start Streaming Server -
- Documents - | Models List -

+

+ 快速开始 + | 快速使用服务 + | 快速使用流式服务 + | 教程文档 + | 模型列表 + | AIStudio 课程 + | 论文 + | Gitee +

------------------------------------------------------------------------------------ -
-

- 快速开始 - | 快速使用服务 - | 快速使用流式服务 - | 教程文档 - | 模型列表 -

- - - - - - - **PaddleSpeech** 是基于飞桨 [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) 的语音方向的开源模型库,用于语音和音频中的各种关键任务的开发,包含大量基于深度学习前沿和有影响力的模型,一些典型的应用示例如下: ##### 语音识别 @@ -179,38 +160,30 @@ from https://github.com/18F/open-source-guide/blob/18f-pages/pages/making-readme ### 近期更新 - -- 👑 2022.05.13: PaddleSpeech 发布 [PP-ASR](./docs/source/asr/PPASR_cn.md)、[PP-TTS](./docs/source/tts/PPTTS_cn.md)、[PP-VPR](docs/source/vpr/PPVPR_cn.md) +- 👑 2022.05.13: PaddleSpeech 发布 [PP-ASR](./docs/source/asr/PPASR_cn.md) 流式语音识别系统、[PP-TTS](./docs/source/tts/PPTTS_cn.md) 流式语音合成系统、[PP-VPR](docs/source/vpr/PPVPR_cn.md) 全链路声纹识别系统 - 👏🏻 2022.05.06: PaddleSpeech Streaming Server 上线! 覆盖了语音识别(标点恢复、时间戳),和语音合成。 - 👏🏻 2022.05.06: PaddleSpeech Server 上线! 覆盖了声音分类、语音识别、语音合成、声纹识别,标点恢复。 - 👏🏻 2022.03.28: PaddleSpeech CLI 覆盖声音分类、语音识别、语音翻译(英译中)、语音合成,声纹验证。 - 🤗 2021.12.14: PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available! -### 🔥 热门活动 -- 2021.12.21~12.24 + ### 🔥 加入技术交流群获取入群福利 - 4 日直播课: 深度解读 PaddleSpeech 语音技术! - - **直播回放与课件资料: https://aistudio.baidu.com/aistudio/education/group/info/25130** - - -### 技术交流群 -微信扫描二维码(好友申请通过后回复【语音】)加入官方交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。 + - 3 日直播课链接: 深度解读 PP-TTS、PP-ASR、PP-VPR 三项核心语音系统关键技术 + - 20G 学习大礼包:视频课程、前沿论文与学习资料 + +微信扫描二维码关注公众号,点击“马上报名”填写问卷加入官方交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
- +
- ## 安装 我们强烈建议用户在 **Linux** 环境下,*3.7* 以上版本的 *python* 上安装 PaddleSpeech。 目前为止,**Linux** 支持声音分类、语音识别、语音合成和语音翻译四种功能,**Mac OSX、 Windows** 下暂不支持语音翻译功能。 想了解具体安装细节,可以参考[安装文档](./docs/source/install_cn.md)。 - + ## 快速开始 安装完成后,开发者可以通过命令行快速开始,改变 `--input` 可以尝试用自己的音频或文本测试。 @@ -257,7 +230,7 @@ paddlespeech asr --input ./zh.wav | paddlespeech text --task punc 更多命令行命令请参考 [demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos) > Note: 如果需要训练或者微调,请查看[语音识别](./docs/source/asr/quick_start.md), [语音合成](./docs/source/tts/quick_start.md)。 - + ## 快速使用服务 安装完成后,开发者可以通过命令行快速使用服务。 @@ -283,30 +256,30 @@ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav 更多服务相关的命令行使用信息,请参考 [demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/speech_server) - + ## 快速使用流式服务 -开发者可以尝试[流式ASR](./demos/streaming_asr_server/README.md)和 [流式TTS](./demos/streaming_tts_server/README.md)服务. +开发者可以尝试 [流式 ASR](./demos/streaming_asr_server/README.md) 和 [流式 TTS](./demos/streaming_tts_server/README.md) 服务. 
-**启动流式ASR服务** +**启动流式 ASR 服务** ``` paddlespeech_server start --config_file ./demos/streaming_asr_server/conf/application.yaml ``` -**访问流式ASR服务** +**访问流式 ASR 服务** ``` paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav ``` -**启动流式TTS服务** +**启动流式 TTS 服务** ``` paddlespeech_server start --config_file ./demos/streaming_tts_server/conf/tts_online_application.yaml ``` -**访问流式TTS服务** +**访问流式 TTS 服务** ``` paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --protocol http --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav @@ -314,8 +287,7 @@ paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --protocol http 更多信息参看: [流式 ASR](./demos/streaming_asr_server/README.md) 和 [流式 TTS](./demos/streaming_tts_server/README.md) - - + ## 模型列表 PaddleSpeech 支持很多主流的模型,并提供了预训练模型,详情请见[模型列表](./docs/source/released_model.md)。 @@ -587,6 +559,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 + ## 教程文档 对于 PaddleSpeech 的所关注的任务,以下指南有助于帮助开发者快速入门,了解语音相关核心思想。 @@ -668,7 +641,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 ## 参与 PaddleSpeech 的开发 -热烈欢迎您在[Discussions](https://github.com/PaddlePaddle/PaddleSpeech/discussions) 中提交问题,并在[Issues](https://github.com/PaddlePaddle/PaddleSpeech/issues) 中指出发现的 bug。此外,我们非常希望您参与到 PaddleSpeech 的开发中! +热烈欢迎您在 [Discussions](https://github.com/PaddlePaddle/PaddleSpeech/discussions) 中提交问题,并在 [Issues](https://github.com/PaddlePaddle/PaddleSpeech/issues) 中指出发现的 bug。此外,我们非常希望您参与到 PaddleSpeech 的开发中! ### 贡献者

diff --git a/audio/.gitignore b/audio/.gitignore deleted file mode 100644 index 1c930053d..000000000 --- a/audio/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -.eggs -*.wav diff --git a/audio/CHANGELOG.md b/audio/CHANGELOG.md deleted file mode 100644 index 925d77696..000000000 --- a/audio/CHANGELOG.md +++ /dev/null @@ -1,9 +0,0 @@ -# Changelog - -Date: 2022-3-15, Author: Xiaojie Chen. - - kaldi and librosa mfcc, fbank, spectrogram. - - unit test and benchmark. - -Date: 2022-2-25, Author: Hui Zhang. - - Refactor architecture. - - dtw distance and mcd style dtw. diff --git a/audio/README.md b/audio/README.md deleted file mode 100644 index 697c01739..000000000 --- a/audio/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# PaddleAudio - -PaddleAudio is an audio library for PaddlePaddle. - -## Install - -`pip install .` diff --git a/audio/docs/Makefile b/audio/docs/Makefile deleted file mode 100644 index 69fe55ecf..000000000 --- a/audio/docs/Makefile +++ /dev/null @@ -1,19 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/audio/docs/README.md b/audio/docs/README.md deleted file mode 100644 index 20626f52b..000000000 --- a/audio/docs/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Build docs for PaddleAudio - -Execute the following steps in **current directory**. - -## 1. Install - -`pip install Sphinx sphinx_rtd_theme` - - -## 2. Generate API docs - -Generate API docs from doc string. 
- -`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates` - - -## 3. Build - -`sphinx-build source _html` - - -## 4. Preview - -Open `_html/index.html` for page preview. diff --git a/audio/docs/images/paddle.png b/audio/docs/images/paddle.png deleted file mode 100644 index bc1135abf..000000000 Binary files a/audio/docs/images/paddle.png and /dev/null differ diff --git a/audio/docs/make.bat b/audio/docs/make.bat deleted file mode 100644 index 543c6b13b..000000000 --- a/audio/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% - -:end -popd diff --git a/audio/paddleaudio/metric/dtw.py b/audio/paddleaudio/metric/dtw.py deleted file mode 100644 index 662e4506d..000000000 --- a/audio/paddleaudio/metric/dtw.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -from dtaidistance import dtw_ndim - -__all__ = [ - 'dtw_distance', -] - - -def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float: - """Dynamic Time Warping. - This function keeps a compact matrix, not the full warping paths matrix. - Uses dynamic programming to compute: - - Examples: - .. code-block:: python - - wps[i, j] = (s1[i]-s2[j])**2 + min( - wps[i-1, j ] + penalty, // vertical / insertion / expansion - wps[i , j-1] + penalty, // horizontal / deletion / compression - wps[i-1, j-1]) // diagonal / match - - dtw = sqrt(wps[-1, -1]) - - Args: - xs (np.ndarray): ref sequence, [T,D] - ys (np.ndarray): hyp sequence, [T,D] - - Returns: - float: dtw distance - """ - return dtw_ndim.distance(xs, ys) diff --git a/audio/paddleaudio/utils/env.py b/audio/paddleaudio/utils/env.py deleted file mode 100644 index a2d14b89e..000000000 --- a/audio/paddleaudio/utils/env.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-''' -This module is used to store environmental variables in PaddleAudio. -PPAUDIO_HOME --> the root directory for storing PaddleAudio related data. Default to ~/.paddleaudio. Users can change the -├ default value through the PPAUDIO_HOME environment variable. -├─ MODEL_HOME --> Store model files. -└─ DATA_HOME --> Store automatically downloaded datasets. -''' -import os - -__all__ = [ - 'USER_HOME', - 'PPAUDIO_HOME', - 'MODEL_HOME', - 'DATA_HOME', -] - - -def _get_user_home(): - return os.path.expanduser('~') - - -def _get_ppaudio_home(): - if 'PPAUDIO_HOME' in os.environ: - home_path = os.environ['PPAUDIO_HOME'] - if os.path.exists(home_path): - if os.path.isdir(home_path): - return home_path - else: - raise RuntimeError( - 'The environment variable PPAUDIO_HOME {} is not a directory.'. - format(home_path)) - else: - return home_path - return os.path.join(_get_user_home(), '.paddleaudio') - - -def _get_sub_home(directory): - home = os.path.join(_get_ppaudio_home(), directory) - if not os.path.exists(home): - os.makedirs(home) - return home - - -USER_HOME = _get_user_home() -PPAUDIO_HOME = _get_ppaudio_home() -MODEL_HOME = _get_sub_home('models') -DATA_HOME = _get_sub_home('datasets') diff --git a/audio/setup.py b/audio/setup.py deleted file mode 100644 index ec67c81de..000000000 --- a/audio/setup.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import glob -import os - -import setuptools -from setuptools.command.install import install -from setuptools.command.test import test - -# set the version here -VERSION = '0.0.0' - - -# Inspired by the example at https://pytest.org/latest/goodpractises.html -class TestCommand(test): - def finalize_options(self): - test.finalize_options(self) - self.test_args = [] - self.test_suite = True - - def run(self): - self.run_benchmark() - super(TestCommand, self).run() - - def run_tests(self): - # Run nose ensuring that argv simulates running nosetests directly - import nose - nose.run_exit(argv=['nosetests', '-w', 'tests']) - - def run_benchmark(self): - for benchmark_item in glob.glob('tests/benchmark/*py'): - os.system(f'pytest {benchmark_item}') - - -class InstallCommand(install): - def run(self): - install.run(self) - - -def write_version_py(filename='paddleaudio/__init__.py'): - with open(filename, "a") as f: - f.write(f"__version__ = '{VERSION}'") - - -def remove_version_py(filename='paddleaudio/__init__.py'): - with open(filename, "r") as f: - lines = f.readlines() - with open(filename, "w") as f: - for line in lines: - if "__version__" not in line: - f.write(line) - - -remove_version_py() -write_version_py() - -setuptools.setup( - name="paddleaudio", - version=VERSION, - author="", - author_email="", - description="PaddleAudio, in development", - long_description="", - long_description_content_type="text/markdown", - url="", - packages=setuptools.find_packages(include=['paddleaudio*']), - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires='>=3.6', - install_requires=[ - 'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2', - 'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos' - ], - extras_require={ - 'test': [ - 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', - 'torchaudio==0.10.2', 'pytest-benchmark' - ], - }, - cmdclass={ - 
'install': InstallCommand, - 'test': TestCommand, - }, ) - -remove_version_py() diff --git a/audio/tests/.gitkeep b/audio/tests/.gitkeep deleted file mode 100644 index e69de29bb..000000000 diff --git a/demos/README.md b/demos/README.md index 8abd67249..2a306df6b 100644 --- a/demos/README.md +++ b/demos/README.md @@ -2,14 +2,14 @@ ([简体中文](./README_cn.md)|English) -The directory containes many speech applications in multi scenarios. +This directory contains many speech applications in multiple scenarios. * audio searching - mass audio similarity retrieval * audio tagging - multi-label tagging of an audio file -* automatic_video_subtitiles - generate subtitles from a video +* automatic_video_subtitles - generate subtitles from a video * metaverse - 2D AR with TTS * punctuation_restoration - restore punctuation from raw text -* speech recogintion - recognize text of an audio file +* speech recognition - recognize text of an audio file * speech server - Server for Speech Task, e.g. ASR,TTS,CLS * streaming asr server - receive audio stream from websocket, and recognize to transcript. * speech translation - end to end speech translation diff --git a/demos/audio_content_search/README.md b/demos/audio_content_search/README.md index d73d6a59d..4428bf389 100644 --- a/demos/audio_content_search/README.md +++ b/demos/audio_content_search/README.md @@ -16,7 +16,12 @@ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/doc You can choose one way from meduim and hard to install paddlespeech. -The dependency refers to the requirements.txt +The dependency refers to the requirements.txt, and install the dependency as follows: + +``` +pip install -r requirements.txt +``` + ### 2. Prepare Input The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. 
diff --git a/demos/audio_content_search/README_cn.md b/demos/audio_content_search/README_cn.md index c74af4cf1..6f51c4cf2 100644 --- a/demos/audio_content_search/README_cn.md +++ b/demos/audio_content_search/README_cn.md @@ -16,7 +16,11 @@ 请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 你可以从 medium,hard 三中方式中选择一种方式安装。 -依赖参见 requirements.txt +依赖参见 requirements.txt, 安装依赖 + +``` +pip install -r requirements.txt +``` ### 2. 准备输入 这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 diff --git a/demos/audio_content_search/conf/acs_application.yaml b/demos/audio_content_search/conf/acs_application.yaml index d3c5e3039..dbddd06fb 100644 --- a/demos/audio_content_search/conf/acs_application.yaml +++ b/demos/audio_content_search/conf/acs_application.yaml @@ -28,6 +28,7 @@ acs_python: word_list: "./conf/words.txt" sample_rate: 16000 device: 'cpu' # set 'gpu:id' or 'cpu' + ping_timeout: 100 # seconds diff --git a/demos/audio_content_search/requirements.txt b/demos/audio_content_search/requirements.txt new file mode 100644 index 000000000..4126a4868 --- /dev/null +++ b/demos/audio_content_search/requirements.txt @@ -0,0 +1 @@ +websocket-client \ No newline at end of file diff --git a/demos/audio_content_search/streaming_asr_server.py b/demos/audio_content_search/streaming_asr_server.py new file mode 100644 index 000000000..011b009aa --- /dev/null +++ b/demos/audio_content_search/streaming_asr_server.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from paddlespeech.cli.log import logger +from paddlespeech.server.bin.paddlespeech_server import ServerExecutor +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='paddlespeech_server.start', add_help=True) + parser.add_argument( + "--config_file", + action="store", + help="yaml file of the app", + default=None, + required=True) + + parser.add_argument( + "--log_file", + action="store", + help="log file", + default="./log/paddlespeech.log") + logger.info("start to parse the args") + args = parser.parse_args() + + logger.info("start to launch the streaming asr server") + streaming_asr_server = ServerExecutor() + streaming_asr_server(config_file=args.config_file, log_file=args.log_file) diff --git a/demos/audio_searching/README.md b/demos/audio_searching/README.md index e829d991a..db38d14ed 100644 --- a/demos/audio_searching/README.md +++ b/demos/audio_searching/README.md @@ -89,7 +89,7 @@ Then to start the system server, and it provides HTTP backend services. Then start the server with Fastapi. 
```bash - export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio + export PYTHONPATH=$PYTHONPATH:./src python src/audio_search.py ``` diff --git a/demos/audio_searching/README_cn.md b/demos/audio_searching/README_cn.md index c13742af7..6d38b91f5 100644 --- a/demos/audio_searching/README_cn.md +++ b/demos/audio_searching/README_cn.md @@ -91,7 +91,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 启动用 Fastapi 构建的服务 ```bash - export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio + export PYTHONPATH=$PYTHONPATH:./src python src/audio_search.py ``` diff --git a/demos/audio_searching/src/encode.py b/demos/audio_searching/src/encode.py index c89a11c1f..f6bcb00ad 100644 --- a/demos/audio_searching/src/encode.py +++ b/demos/audio_searching/src/encode.py @@ -14,7 +14,7 @@ import numpy as np from logs import LOGGER -from paddlespeech.cli import VectorExecutor +from paddlespeech.cli.vector import VectorExecutor vector_executor = VectorExecutor() diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py index d1ea00576..0d9edb784 100644 --- a/demos/audio_searching/src/operations/load.py +++ b/demos/audio_searching/src/operations/load.py @@ -26,9 +26,8 @@ def get_audios(path): """ supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] return [ - item - for sublist in [[os.path.join(dir, file) for file in files] - for dir, _, files in list(os.walk(path))] + item for sublist in [[os.path.join(dir, file) for file in files] + for dir, _, files in list(os.walk(path))] for item in sublist if os.path.splitext(item)[1] in supported_formats ] diff --git a/demos/audio_tagging/README.md b/demos/audio_tagging/README.md index 9d4af0be6..fc4a334ea 100644 --- a/demos/audio_tagging/README.md +++ b/demos/audio_tagging/README.md @@ -57,7 +57,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe - Python API ```python import paddle - from paddlespeech.cli import CLSExecutor + from 
paddlespeech.cli.cls import CLSExecutor cls_executor = CLSExecutor() result = cls_executor( diff --git a/demos/audio_tagging/README_cn.md b/demos/audio_tagging/README_cn.md index 79f87bf8c..36b5d8aaf 100644 --- a/demos/audio_tagging/README_cn.md +++ b/demos/audio_tagging/README_cn.md @@ -57,7 +57,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe - Python API ```python import paddle - from paddlespeech.cli import CLSExecutor + from paddlespeech.cli.cls import CLSExecutor cls_executor = CLSExecutor() result = cls_executor( diff --git a/demos/automatic_video_subtitiles/README.md b/demos/automatic_video_subtitiles/README.md index db6da40db..b815425ec 100644 --- a/demos/automatic_video_subtitiles/README.md +++ b/demos/automatic_video_subtitiles/README.md @@ -28,7 +28,8 @@ ffmpeg -i subtitle_demo1.mp4 -ac 1 -ar 16000 -vn input.wav - Python API ```python import paddle - from paddlespeech.cli import ASRExecutor, TextExecutor + from paddlespeech.cli.asr import ASRExecutor + from paddlespeech.cli.text import TextExecutor asr_executor = ASRExecutor() text_executor = TextExecutor() diff --git a/demos/automatic_video_subtitiles/README_cn.md b/demos/automatic_video_subtitiles/README_cn.md index fc7b2cf6a..990ff6dbd 100644 --- a/demos/automatic_video_subtitiles/README_cn.md +++ b/demos/automatic_video_subtitiles/README_cn.md @@ -23,7 +23,8 @@ ffmpeg -i subtitle_demo1.mp4 -ac 1 -ar 16000 -vn input.wav - Python API ```python import paddle - from paddlespeech.cli import ASRExecutor, TextExecutor + from paddlespeech.cli.asr import ASRExecutor + from paddlespeech.cli.text import TextExecutor asr_executor = ASRExecutor() text_executor = TextExecutor() diff --git a/demos/automatic_video_subtitiles/recognize.py b/demos/automatic_video_subtitiles/recognize.py index 72e3c3a85..304599d19 100644 --- a/demos/automatic_video_subtitiles/recognize.py +++ b/demos/automatic_video_subtitiles/recognize.py @@ -16,8 +16,8 @@ import os import paddle -from 
paddlespeech.cli import ASRExecutor -from paddlespeech.cli import TextExecutor +from paddlespeech.cli.asr import ASRExecutor +from paddlespeech.cli.text import TextExecutor # yapf: disable parser = argparse.ArgumentParser(__doc__) diff --git a/demos/custom_streaming_asr/README.md b/demos/custom_streaming_asr/README.md index aa28d502f..da86e90ab 100644 --- a/demos/custom_streaming_asr/README.md +++ b/demos/custom_streaming_asr/README.md @@ -3,10 +3,13 @@ # Customized Auto Speech Recognition ## introduction + In some cases, we need to recognize the specific rare words with high accuracy. eg: address recognition in navigation apps. customized ASR can slove those issues. this demo is customized for expense account, which need to recognize rare address. +the scripts are in https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/speechx/examples/custom_asr + * G with slot: 打车到 "address_slot"。 ![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4) @@ -62,4 +65,4 @@ I0513 10:58:13.884493 41768 feature_cache.h:52] set finished I0513 10:58:24.247171 41768 paddle_nnet.h:76] Tensor neml: 10240 I0513 10:58:24.247249 41768 paddle_nnet.h:76] Tensor neml: 10240 LOG ([5.5.544~2-f21d7]:main():decoder/recognizer_test_main.cc:90) the result of case_10 is 五月十二日二十二点三十六分加班打车回家四十一元 -``` \ No newline at end of file +``` diff --git a/demos/custom_streaming_asr/README_cn.md b/demos/custom_streaming_asr/README_cn.md index ffbf682fb..f9981a6ae 100644 --- a/demos/custom_streaming_asr/README_cn.md +++ b/demos/custom_streaming_asr/README_cn.md @@ -6,6 +6,8 @@ 这个 demo 是打车报销单的场景识别,需要识别一些稀有的地名,可以通过如下操作实现。 +相关脚本:https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/speechx/examples/custom_asr + * G with slot: 打车到 "address_slot"。 ![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4) diff --git a/demos/punctuation_restoration/README.md 
b/demos/punctuation_restoration/README.md index 518d437dc..458ab92f9 100644 --- a/demos/punctuation_restoration/README.md +++ b/demos/punctuation_restoration/README.md @@ -42,7 +42,7 @@ The input of this demo should be a text of the specific language that can be pas - Python API ```python import paddle - from paddlespeech.cli import TextExecutor + from paddlespeech.cli.text import TextExecutor text_executor = TextExecutor() result = text_executor( diff --git a/demos/punctuation_restoration/README_cn.md b/demos/punctuation_restoration/README_cn.md index 9d4be8bf0..f25acdadb 100644 --- a/demos/punctuation_restoration/README_cn.md +++ b/demos/punctuation_restoration/README_cn.md @@ -44,7 +44,7 @@ - Python API ```python import paddle - from paddlespeech.cli import TextExecutor + from paddlespeech.cli.text import TextExecutor text_executor = TextExecutor() result = text_executor( diff --git a/demos/speaker_verification/README.md b/demos/speaker_verification/README.md index b6a1d9bcc..900b5ae40 100644 --- a/demos/speaker_verification/README.md +++ b/demos/speaker_verification/README.md @@ -53,51 +53,50 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav Output: ```bash - demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 - 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 - 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 - -9.723131 0.6619743 -6.976803 10.213478 7.494748 - 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 - -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 - 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 - -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 - 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 - -3.539873 3.814236 5.1420674 2.162061 4.096431 - -6.4162116 12.747448 1.9429878 -15.152943 6.417416 - 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 - 11.567354 3.69788 11.258265 7.442363 9.183411 - 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 - 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 - 
-3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 - 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 - -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 - -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 - -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 - 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 - -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 - -0.31784213 9.493548 2.1144536 4.358092 -12.089823 - 8.451689 -7.925461 4.6242585 4.4289427 18.692003 - -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 - -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 - 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 - -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 - 0.66607 15.443222 4.740594 -3.4725387 11.592567 - -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 - -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 - -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 - -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 - 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 - 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 - 4.532936 2.7264361 10.145339 -6.521951 2.897153 - -3.3925855 5.079156 7.759716 4.677565 5.8457737 - 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 - -3.7760346 -11.118123 ] + demo [ -1.3251206 7.8606825 -4.620626 0.3000721 2.2648535 + -1.1931441 3.0647137 7.673595 -6.0044727 -12.02426 + -1.9496069 3.1269536 1.618838 -7.6383104 -1.2299773 + -12.338331 2.1373026 -5.3957124 9.717328 5.6752305 + 3.7805123 3.0597172 3.429692 8.97601 13.174125 + -0.53132284 8.9424715 4.46511 -4.4262476 -9.726503 + 8.399328 7.2239175 -7.435854 2.9441683 -4.3430395 + -13.886965 -1.6346735 -10.9027405 -5.311245 3.8007221 + 3.8976038 -2.1230774 -2.3521194 4.151031 -7.4048667 + 0.13911647 2.4626107 4.9664545 0.9897574 5.4839754 + -3.3574002 10.1340065 -0.6120171 -10.403095 4.6007543 + 16.00935 -7.7836914 -4.1945305 -6.9368606 1.1789556 + 11.490801 4.2380238 9.550931 8.375046 7.5089145 + -0.65707296 -0.30051577 2.8406055 3.0828028 0.730817 + 
6.148354 0.13766119 -13.424735 -7.7461405 -2.3227983 + -8.305252 2.9879124 -10.995229 0.15211068 -2.3820348 + -1.7984174 8.495629 -5.8522367 -3.755498 0.6989711 + -5.2702994 -2.6188622 -1.8828466 -4.64665 14.078544 + -0.5495333 10.579158 -3.2160501 9.349004 -4.381078 + -11.675817 -2.8630207 4.5721755 2.246612 -4.574342 + 1.8610188 2.3767874 5.6257877 -9.784078 0.64967257 + -1.4579505 0.4263264 -4.9211264 -2.454784 3.4869802 + -0.42654222 8.341269 1.356552 7.0966883 -13.102829 + 8.016734 -7.1159344 1.8699781 0.208721 14.699384 + -1.025278 -2.6107233 -2.5082312 8.427193 6.9138527 + -6.2912464 0.6157366 2.489688 -3.4668267 9.921763 + 11.200815 -0.1966403 7.4916005 -0.62312716 -0.25848144 + -9.947997 -0.9611041 1.1649219 -2.1907122 -1.5028487 + -0.51926106 15.165954 2.4649463 -0.9980445 7.4416637 + -2.0768049 3.5896823 -7.3055434 -7.5620847 4.323335 + 0.0804418 -6.56401 -2.3148053 -1.7642345 -2.4708817 + -7.675618 -9.548878 -1.0177554 0.16986446 2.5877135 + -1.8752296 -0.36614323 -6.0493784 -2.3965611 -5.9453387 + 0.9424033 -13.155974 -7.457801 0.14658108 -3.742797 + 5.8414927 -1.2872906 5.5694313 12.57059 1.0939219 + 2.2142086 1.9181576 6.9914207 -5.888139 3.1409824 + -2.003628 2.4434285 9.973139 5.03668 2.0051203 + 2.8615603 5.860224 2.9176188 -1.6311141 2.0292206 + -4.070415 -6.831437 ] ``` - Python API ```python - import paddle - from paddlespeech.cli import VectorExecutor + from paddlespeech.cli.vector import VectorExecutor vector_executor = VectorExecutor() audio_emb = vector_executor( @@ -128,88 +127,88 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav ```bash # Vector Result: Audio embedding Result: - [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 - 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 - 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 - -9.723131 0.6619743 -6.976803 10.213478 7.494748 - 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 - -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 - 11.232214 7.1274667 
-4.2828417 2.452362 -5.130748 - -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 - 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 - -3.539873 3.814236 5.1420674 2.162061 4.096431 - -6.4162116 12.747448 1.9429878 -15.152943 6.417416 - 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 - 11.567354 3.69788 11.258265 7.442363 9.183411 - 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 - 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 - -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 - 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 - -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 - -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 - -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 - 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 - -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 - -0.31784213 9.493548 2.1144536 4.358092 -12.089823 - 8.451689 -7.925461 4.6242585 4.4289427 18.692003 - -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 - -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 - 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 - -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 - 0.66607 15.443222 4.740594 -3.4725387 11.592567 - -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 - -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 - -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 - -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 - 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 - 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 - 4.532936 2.7264361 10.145339 -6.521951 2.897153 - -3.3925855 5.079156 7.759716 4.677565 5.8457737 - 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 - -3.7760346 -11.118123 ] + [ -1.3251206 7.8606825 -4.620626 0.3000721 2.2648535 + -1.1931441 3.0647137 7.673595 -6.0044727 -12.02426 + -1.9496069 3.1269536 1.618838 -7.6383104 -1.2299773 + -12.338331 2.1373026 -5.3957124 9.717328 5.6752305 + 3.7805123 3.0597172 3.429692 8.97601 13.174125 + -0.53132284 8.9424715 4.46511 -4.4262476 
-9.726503 + 8.399328 7.2239175 -7.435854 2.9441683 -4.3430395 + -13.886965 -1.6346735 -10.9027405 -5.311245 3.8007221 + 3.8976038 -2.1230774 -2.3521194 4.151031 -7.4048667 + 0.13911647 2.4626107 4.9664545 0.9897574 5.4839754 + -3.3574002 10.1340065 -0.6120171 -10.403095 4.6007543 + 16.00935 -7.7836914 -4.1945305 -6.9368606 1.1789556 + 11.490801 4.2380238 9.550931 8.375046 7.5089145 + -0.65707296 -0.30051577 2.8406055 3.0828028 0.730817 + 6.148354 0.13766119 -13.424735 -7.7461405 -2.3227983 + -8.305252 2.9879124 -10.995229 0.15211068 -2.3820348 + -1.7984174 8.495629 -5.8522367 -3.755498 0.6989711 + -5.2702994 -2.6188622 -1.8828466 -4.64665 14.078544 + -0.5495333 10.579158 -3.2160501 9.349004 -4.381078 + -11.675817 -2.8630207 4.5721755 2.246612 -4.574342 + 1.8610188 2.3767874 5.6257877 -9.784078 0.64967257 + -1.4579505 0.4263264 -4.9211264 -2.454784 3.4869802 + -0.42654222 8.341269 1.356552 7.0966883 -13.102829 + 8.016734 -7.1159344 1.8699781 0.208721 14.699384 + -1.025278 -2.6107233 -2.5082312 8.427193 6.9138527 + -6.2912464 0.6157366 2.489688 -3.4668267 9.921763 + 11.200815 -0.1966403 7.4916005 -0.62312716 -0.25848144 + -9.947997 -0.9611041 1.1649219 -2.1907122 -1.5028487 + -0.51926106 15.165954 2.4649463 -0.9980445 7.4416637 + -2.0768049 3.5896823 -7.3055434 -7.5620847 4.323335 + 0.0804418 -6.56401 -2.3148053 -1.7642345 -2.4708817 + -7.675618 -9.548878 -1.0177554 0.16986446 2.5877135 + -1.8752296 -0.36614323 -6.0493784 -2.3965611 -5.9453387 + 0.9424033 -13.155974 -7.457801 0.14658108 -3.742797 + 5.8414927 -1.2872906 5.5694313 12.57059 1.0939219 + 2.2142086 1.9181576 6.9914207 -5.888139 3.1409824 + -2.003628 2.4434285 9.973139 5.03668 2.0051203 + 2.8615603 5.860224 2.9176188 -1.6311141 2.0292206 + -4.070415 -6.831437 ] # get the test embedding Test embedding Result: - [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125 - 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845 - -8.04332 4.344489 2.3200977 -14.306299 5.184692 - -11.55602 -3.8497238 0.6444722 
1.2833948 2.6766639 - 0.5878921 0.7946299 1.7207596 2.5791872 14.998469 - -1.3385371 15.031221 -0.8006958 1.99287 -9.52007 - 2.435466 4.003221 -4.33817 -4.898601 -5.304714 - -18.033886 10.790787 -12.784645 -5.641755 2.9761686 - -10.566622 1.4839455 6.152458 -5.7195854 2.8603241 - 6.112133 8.489869 5.5958056 1.2836679 -1.2293907 - 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906 - 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029 - 10.521315 0.9604709 -3.3257897 7.144871 -13.592733 - -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123 - 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357 - -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422 - 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846 - 1.3420814 4.240421 -2.772944 -2.8451524 16.311104 - 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239 - -1.5708797 1.568961 1.1413603 3.5032008 -0.45251232 - -6.786333 16.89443 5.3366146 -8.789056 0.6355629 - 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468 - -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873 - 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476 - 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166 - -5.249467 -2.2671914 7.2658715 -13.298164 4.821147 - -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838 - -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094 - -3.4903061 2.2408795 5.5010734 -3.970756 11.99696 - -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706 - 11.635366 0.72157705 -9.245714 -3.91465 -4.449838 - -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864 - 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571 - 13.270688 3.448231 -7.0659585 4.5886116 -4.466099 - -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137 - -1.9962958 2.7119343 19.391657 0.01961198 14.607133 - -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604 - 12.0612335 5.9285784 3.3715196 1.492534 10.723728 - -0.95514804 -12.085431 ] + [ 2.5247195 5.119042 -4.335273 4.4583654 5.047907 + 3.5059214 1.6159848 0.49364898 -11.6899185 -3.1014526 + -5.6589785 -0.42684984 2.674276 
-11.937654 6.2248464 + -10.776924 -5.694543 1.112041 1.5709964 1.0961034 + 1.3976512 2.324352 1.339981 5.279319 13.734659 + -2.5753925 13.651442 -2.2357535 5.1575427 -3.251567 + 1.4023279 6.1191974 -6.0845175 -1.3646189 -2.6789894 + -15.220778 9.779349 -9.411551 -6.388947 6.8313975 + -9.245996 0.31196198 2.5509644 -4.413065 6.1649427 + 6.793837 2.6328635 8.620976 3.4832475 0.52491665 + 2.9115407 5.8392377 0.6702376 -3.2726715 2.6694255 + 16.91701 -5.5811176 0.23362345 -4.5573606 -11.801059 + 14.728292 -0.5198082 -3.999922 7.0927105 -7.0459595 + -5.4389 -0.46420583 -5.1085467 10.376568 -8.889225 + -0.37705845 -1.659806 2.6731026 -7.1909504 1.4608804 + -2.163136 -0.17949677 4.0241547 0.11319201 0.601279 + 2.039692 3.1910992 -11.649526 -8.121584 -4.8707457 + 0.3851982 1.4231744 -2.3321972 0.99332285 14.121717 + 5.899413 0.7384519 -17.760096 10.555021 4.1366534 + -0.3391071 -0.20792882 3.208204 0.8847948 -8.721497 + -6.432868 13.006379 4.8956 -9.155822 -1.9441519 + 5.7815638 -2.066733 10.425042 -0.8802383 -2.4314315 + -9.869258 0.35095334 -5.3549943 2.1076174 -8.290468 + 8.4433365 -4.689333 9.334139 -2.172678 -3.0250976 + 8.394216 -3.2110903 -7.93868 2.3960824 -2.3213403 + -1.4963245 -3.476059 4.132903 -10.893354 4.362673 + -0.45456508 10.258634 -1.1655927 -6.7799754 0.22885278 + -4.399287 2.333433 -4.84745 -4.2752337 -1.3577863 + -1.0685898 9.505196 7.3062205 0.08708266 12.927811 + -9.57974 1.3936648 -1.9444873 5.776769 15.251903 + 10.6118355 -1.4903594 -9.535318 -3.6553776 -1.6699586 + -0.5933151 7.600357 -4.8815503 -8.698617 -15.855757 + 0.25632986 -7.2235737 0.9506656 0.7128582 -9.051738 + 8.74869 -1.6426028 -6.5762258 2.506905 -6.7431564 + 5.129912 -12.189555 -3.6435068 12.068113 -6.0059533 + -2.3535995 2.9014351 22.3082 -1.5563312 13.193291 + 2.7583609 -7.468798 1.3407065 -4.599617 -6.2345777 + 10.7689295 7.137627 5.099476 0.3473359 9.647881 + -2.0484571 -5.8549366 ] # get the score between enroll and test - Eembeddings Score: 0.4292638301849365 + Eembeddings 
Score: 0.45332613587379456 ``` ### 4.Pretrained Models diff --git a/demos/speaker_verification/README_cn.md b/demos/speaker_verification/README_cn.md index 90bba38ac..f6afa86ac 100644 --- a/demos/speaker_verification/README_cn.md +++ b/demos/speaker_verification/README_cn.md @@ -51,51 +51,51 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav 输出: ```bash - demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 - 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 - 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 - -9.723131 0.6619743 -6.976803 10.213478 7.494748 - 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 - -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 - 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 - -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 - 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 - -3.539873 3.814236 5.1420674 2.162061 4.096431 - -6.4162116 12.747448 1.9429878 -15.152943 6.417416 - 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 - 11.567354 3.69788 11.258265 7.442363 9.183411 - 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 - 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 - -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 - 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 - -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 - -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 - -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 - 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 - -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 - -0.31784213 9.493548 2.1144536 4.358092 -12.089823 - 8.451689 -7.925461 4.6242585 4.4289427 18.692003 - -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 - -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 - 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 - -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 - 0.66607 15.443222 4.740594 -3.4725387 11.592567 - -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 - -1.5180256 -7.746615 -6.089606 
0.07112726 -0.34904733 - -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 - -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 - 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 - 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 - 4.532936 2.7264361 10.145339 -6.521951 2.897153 - -3.3925855 5.079156 7.759716 4.677565 5.8457737 - 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 - -3.7760346 -11.118123 ] + [ -1.3251206 7.8606825 -4.620626 0.3000721 2.2648535 + -1.1931441 3.0647137 7.673595 -6.0044727 -12.02426 + -1.9496069 3.1269536 1.618838 -7.6383104 -1.2299773 + -12.338331 2.1373026 -5.3957124 9.717328 5.6752305 + 3.7805123 3.0597172 3.429692 8.97601 13.174125 + -0.53132284 8.9424715 4.46511 -4.4262476 -9.726503 + 8.399328 7.2239175 -7.435854 2.9441683 -4.3430395 + -13.886965 -1.6346735 -10.9027405 -5.311245 3.8007221 + 3.8976038 -2.1230774 -2.3521194 4.151031 -7.4048667 + 0.13911647 2.4626107 4.9664545 0.9897574 5.4839754 + -3.3574002 10.1340065 -0.6120171 -10.403095 4.6007543 + 16.00935 -7.7836914 -4.1945305 -6.9368606 1.1789556 + 11.490801 4.2380238 9.550931 8.375046 7.5089145 + -0.65707296 -0.30051577 2.8406055 3.0828028 0.730817 + 6.148354 0.13766119 -13.424735 -7.7461405 -2.3227983 + -8.305252 2.9879124 -10.995229 0.15211068 -2.3820348 + -1.7984174 8.495629 -5.8522367 -3.755498 0.6989711 + -5.2702994 -2.6188622 -1.8828466 -4.64665 14.078544 + -0.5495333 10.579158 -3.2160501 9.349004 -4.381078 + -11.675817 -2.8630207 4.5721755 2.246612 -4.574342 + 1.8610188 2.3767874 5.6257877 -9.784078 0.64967257 + -1.4579505 0.4263264 -4.9211264 -2.454784 3.4869802 + -0.42654222 8.341269 1.356552 7.0966883 -13.102829 + 8.016734 -7.1159344 1.8699781 0.208721 14.699384 + -1.025278 -2.6107233 -2.5082312 8.427193 6.9138527 + -6.2912464 0.6157366 2.489688 -3.4668267 9.921763 + 11.200815 -0.1966403 7.4916005 -0.62312716 -0.25848144 + -9.947997 -0.9611041 1.1649219 -2.1907122 -1.5028487 + -0.51926106 15.165954 2.4649463 -0.9980445 7.4416637 + -2.0768049 3.5896823 
-7.3055434 -7.5620847 4.323335 + 0.0804418 -6.56401 -2.3148053 -1.7642345 -2.4708817 + -7.675618 -9.548878 -1.0177554 0.16986446 2.5877135 + -1.8752296 -0.36614323 -6.0493784 -2.3965611 -5.9453387 + 0.9424033 -13.155974 -7.457801 0.14658108 -3.742797 + 5.8414927 -1.2872906 5.5694313 12.57059 1.0939219 + 2.2142086 1.9181576 6.9914207 -5.888139 3.1409824 + -2.003628 2.4434285 9.973139 5.03668 2.0051203 + 2.8615603 5.860224 2.9176188 -1.6311141 2.0292206 + -4.070415 -6.831437 ] ``` - Python API ```python import paddle - from paddlespeech.cli import VectorExecutor + from paddlespeech.cli.vector import VectorExecutor vector_executor = VectorExecutor() audio_emb = vector_executor( @@ -125,88 +125,88 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav ```bash # Vector Result: Audio embedding Result: - [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 - 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 - 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 - -9.723131 0.6619743 -6.976803 10.213478 7.494748 - 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 - -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 - 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 - -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 - 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 - -3.539873 3.814236 5.1420674 2.162061 4.096431 - -6.4162116 12.747448 1.9429878 -15.152943 6.417416 - 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 - 11.567354 3.69788 11.258265 7.442363 9.183411 - 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 - 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 - -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 - 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 - -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 - -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 - -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 - 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 - -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 - 
-0.31784213 9.493548 2.1144536 4.358092 -12.089823 - 8.451689 -7.925461 4.6242585 4.4289427 18.692003 - -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 - -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 - 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 - -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 - 0.66607 15.443222 4.740594 -3.4725387 11.592567 - -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 - -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 - -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 - -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 - 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 - 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 - 4.532936 2.7264361 10.145339 -6.521951 2.897153 - -3.3925855 5.079156 7.759716 4.677565 5.8457737 - 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 - -3.7760346 -11.118123 ] + [ -1.3251206 7.8606825 -4.620626 0.3000721 2.2648535 + -1.1931441 3.0647137 7.673595 -6.0044727 -12.02426 + -1.9496069 3.1269536 1.618838 -7.6383104 -1.2299773 + -12.338331 2.1373026 -5.3957124 9.717328 5.6752305 + 3.7805123 3.0597172 3.429692 8.97601 13.174125 + -0.53132284 8.9424715 4.46511 -4.4262476 -9.726503 + 8.399328 7.2239175 -7.435854 2.9441683 -4.3430395 + -13.886965 -1.6346735 -10.9027405 -5.311245 3.8007221 + 3.8976038 -2.1230774 -2.3521194 4.151031 -7.4048667 + 0.13911647 2.4626107 4.9664545 0.9897574 5.4839754 + -3.3574002 10.1340065 -0.6120171 -10.403095 4.6007543 + 16.00935 -7.7836914 -4.1945305 -6.9368606 1.1789556 + 11.490801 4.2380238 9.550931 8.375046 7.5089145 + -0.65707296 -0.30051577 2.8406055 3.0828028 0.730817 + 6.148354 0.13766119 -13.424735 -7.7461405 -2.3227983 + -8.305252 2.9879124 -10.995229 0.15211068 -2.3820348 + -1.7984174 8.495629 -5.8522367 -3.755498 0.6989711 + -5.2702994 -2.6188622 -1.8828466 -4.64665 14.078544 + -0.5495333 10.579158 -3.2160501 9.349004 -4.381078 + -11.675817 -2.8630207 4.5721755 2.246612 -4.574342 + 1.8610188 2.3767874 5.6257877 -9.784078 0.64967257 + -1.4579505 
0.4263264 -4.9211264 -2.454784 3.4869802 + -0.42654222 8.341269 1.356552 7.0966883 -13.102829 + 8.016734 -7.1159344 1.8699781 0.208721 14.699384 + -1.025278 -2.6107233 -2.5082312 8.427193 6.9138527 + -6.2912464 0.6157366 2.489688 -3.4668267 9.921763 + 11.200815 -0.1966403 7.4916005 -0.62312716 -0.25848144 + -9.947997 -0.9611041 1.1649219 -2.1907122 -1.5028487 + -0.51926106 15.165954 2.4649463 -0.9980445 7.4416637 + -2.0768049 3.5896823 -7.3055434 -7.5620847 4.323335 + 0.0804418 -6.56401 -2.3148053 -1.7642345 -2.4708817 + -7.675618 -9.548878 -1.0177554 0.16986446 2.5877135 + -1.8752296 -0.36614323 -6.0493784 -2.3965611 -5.9453387 + 0.9424033 -13.155974 -7.457801 0.14658108 -3.742797 + 5.8414927 -1.2872906 5.5694313 12.57059 1.0939219 + 2.2142086 1.9181576 6.9914207 -5.888139 3.1409824 + -2.003628 2.4434285 9.973139 5.03668 2.0051203 + 2.8615603 5.860224 2.9176188 -1.6311141 2.0292206 + -4.070415 -6.831437 ] # get the test embedding Test embedding Result: - [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125 - 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845 - -8.04332 4.344489 2.3200977 -14.306299 5.184692 - -11.55602 -3.8497238 0.6444722 1.2833948 2.6766639 - 0.5878921 0.7946299 1.7207596 2.5791872 14.998469 - -1.3385371 15.031221 -0.8006958 1.99287 -9.52007 - 2.435466 4.003221 -4.33817 -4.898601 -5.304714 - -18.033886 10.790787 -12.784645 -5.641755 2.9761686 - -10.566622 1.4839455 6.152458 -5.7195854 2.8603241 - 6.112133 8.489869 5.5958056 1.2836679 -1.2293907 - 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906 - 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029 - 10.521315 0.9604709 -3.3257897 7.144871 -13.592733 - -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123 - 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357 - -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422 - 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846 - 1.3420814 4.240421 -2.772944 -2.8451524 16.311104 - 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239 - -1.5708797 1.568961 
1.1413603 3.5032008 -0.45251232 - -6.786333 16.89443 5.3366146 -8.789056 0.6355629 - 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468 - -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873 - 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476 - 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166 - -5.249467 -2.2671914 7.2658715 -13.298164 4.821147 - -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838 - -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094 - -3.4903061 2.2408795 5.5010734 -3.970756 11.99696 - -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706 - 11.635366 0.72157705 -9.245714 -3.91465 -4.449838 - -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864 - 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571 - 13.270688 3.448231 -7.0659585 4.5886116 -4.466099 - -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137 - -1.9962958 2.7119343 19.391657 0.01961198 14.607133 - -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604 - 12.0612335 5.9285784 3.3715196 1.492534 10.723728 - -0.95514804 -12.085431 ] + [ 2.5247195 5.119042 -4.335273 4.4583654 5.047907 + 3.5059214 1.6159848 0.49364898 -11.6899185 -3.1014526 + -5.6589785 -0.42684984 2.674276 -11.937654 6.2248464 + -10.776924 -5.694543 1.112041 1.5709964 1.0961034 + 1.3976512 2.324352 1.339981 5.279319 13.734659 + -2.5753925 13.651442 -2.2357535 5.1575427 -3.251567 + 1.4023279 6.1191974 -6.0845175 -1.3646189 -2.6789894 + -15.220778 9.779349 -9.411551 -6.388947 6.8313975 + -9.245996 0.31196198 2.5509644 -4.413065 6.1649427 + 6.793837 2.6328635 8.620976 3.4832475 0.52491665 + 2.9115407 5.8392377 0.6702376 -3.2726715 2.6694255 + 16.91701 -5.5811176 0.23362345 -4.5573606 -11.801059 + 14.728292 -0.5198082 -3.999922 7.0927105 -7.0459595 + -5.4389 -0.46420583 -5.1085467 10.376568 -8.889225 + -0.37705845 -1.659806 2.6731026 -7.1909504 1.4608804 + -2.163136 -0.17949677 4.0241547 0.11319201 0.601279 + 2.039692 3.1910992 -11.649526 -8.121584 -4.8707457 + 0.3851982 1.4231744 -2.3321972 0.99332285 14.121717 + 5.899413 0.7384519 
-17.760096 10.555021 4.1366534 + -0.3391071 -0.20792882 3.208204 0.8847948 -8.721497 + -6.432868 13.006379 4.8956 -9.155822 -1.9441519 + 5.7815638 -2.066733 10.425042 -0.8802383 -2.4314315 + -9.869258 0.35095334 -5.3549943 2.1076174 -8.290468 + 8.4433365 -4.689333 9.334139 -2.172678 -3.0250976 + 8.394216 -3.2110903 -7.93868 2.3960824 -2.3213403 + -1.4963245 -3.476059 4.132903 -10.893354 4.362673 + -0.45456508 10.258634 -1.1655927 -6.7799754 0.22885278 + -4.399287 2.333433 -4.84745 -4.2752337 -1.3577863 + -1.0685898 9.505196 7.3062205 0.08708266 12.927811 + -9.57974 1.3936648 -1.9444873 5.776769 15.251903 + 10.6118355 -1.4903594 -9.535318 -3.6553776 -1.6699586 + -0.5933151 7.600357 -4.8815503 -8.698617 -15.855757 + 0.25632986 -7.2235737 0.9506656 0.7128582 -9.051738 + 8.74869 -1.6426028 -6.5762258 2.506905 -6.7431564 + 5.129912 -12.189555 -3.6435068 12.068113 -6.0059533 + -2.3535995 2.9014351 22.3082 -1.5563312 13.193291 + 2.7583609 -7.468798 1.3407065 -4.599617 -6.2345777 + 10.7689295 7.137627 5.099476 0.3473359 9.647881 + -2.0484571 -5.8549366 ] # get the score between enroll and test - Eembeddings Score: 0.4292638301849365 + Eembeddings Score: 0.45332613587379456 ``` ### 4.预训练模型 diff --git a/demos/speech_recognition/README.md b/demos/speech_recognition/README.md index 6493e8e61..c815a88af 100644 --- a/demos/speech_recognition/README.md +++ b/demos/speech_recognition/README.md @@ -58,7 +58,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - Python API ```python import paddle - from paddlespeech.cli import ASRExecutor + from paddlespeech.cli.asr import ASRExecutor asr_executor = ASRExecutor() text = asr_executor( diff --git a/demos/speech_recognition/README_cn.md b/demos/speech_recognition/README_cn.md index 8d631d89c..13aa9f277 100644 --- a/demos/speech_recognition/README_cn.md +++ b/demos/speech_recognition/README_cn.md @@ -56,7 +56,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - Python 
API ```python import paddle - from paddlespeech.cli import ASRExecutor + from paddlespeech.cli.asr import ASRExecutor asr_executor = ASRExecutor() text = asr_executor( diff --git a/demos/speech_server/README.md b/demos/speech_server/README.md index 5a3de0ccd..14a88f078 100644 --- a/demos/speech_server/README.md +++ b/demos/speech_server/README.md @@ -257,13 +257,13 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav ``` - * Usage: + Usage: ``` bash paddlespeech_client vector --help ``` - * Arguments: + Arguments: * server_ip: server ip. Default: 127.0.0.1 * port: server port. Default: 8090 * input(required): Input text to generate. @@ -271,35 +271,35 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee * enroll: enroll audio * test: test audio - * Output: + Output: - ``` bash - [2022-05-08 00:18:44,249] [ INFO] - vector http client start - [2022-05-08 00:18:44,250] [ INFO] - the input audio: 85236145389.wav - [2022-05-08 00:18:44,250] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector - [2022-05-08 00:18:44,250] [ INFO] - http://127.0.0.1:8590/paddlespeech/vector - [2022-05-08 00:18:44,406] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 
11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, -5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 5.142068862915039, 2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 
16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, -0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, -8.265326499938965, -9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 0.4657334089279175, 3.1326050758361816, 12.438895225524902, -1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} - [2022-05-08 00:18:44,406] [ INFO] - Response time 0.156481 s. 
+ ```bash + [2022-05-25 12:25:36,165] [ INFO] - vector http client start + [2022-05-25 12:25:36,165] [ INFO] - the input audio: 85236145389.wav + [2022-05-25 12:25:36,165] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector + [2022-05-25 12:25:36,166] [ INFO] - http://127.0.0.1:8790/paddlespeech/vector + [2022-05-25 12:25:36,324] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, -3.3574001789093018, 10.13400650024414, -0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, -10.99522876739502, 
0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 0.08044180274009705, -6.564010143280029, -2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, -9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 1.9181575775146484, 
6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}} + [2022-05-25 12:25:36,324] [ INFO] - Response time 0.159053 s. ``` * Python API -``` python -from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor + ``` python + from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor -vectorclient_executor = VectorClientExecutor() -res = vectorclient_executor( - input="85236145389.wav", - server_ip="127.0.0.1", - port=8090, - task="spk") -print(res) -``` + vectorclient_executor = VectorClientExecutor() + res = vectorclient_executor( + input="85236145389.wav", + server_ip="127.0.0.1", + port=8090, + task="spk") + print(res) + ``` -* Output: + Output: ``` bash - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, -5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 5.142068862915039, 
2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, -0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, -8.265326499938965, 
-9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 0.4657334089279175, 3.1326050758361816, 12.438895225524902, -1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} + {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, -3.3574001789093018, 10.13400650024414, 
-0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, -10.99522876739502, 0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 0.08044180274009705, -6.564010143280029, 
-2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, -9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 1.9181575775146484, 6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}} ``` #### 7.2 Get the score between speaker audio embedding @@ -314,13 +314,13 @@ print(res) paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 85236145389.wav --test 123456789.wav ``` - * Usage: + Usage: ``` bash paddlespeech_client vector --help ``` - * Arguments: + Arguments: * server_ip: server ip. Default: 127.0.0.1 * port: server port. Default: 8090 * input(required): Input text to generate. @@ -328,42 +328,42 @@ print(res) * enroll: enroll audio * test: test audio -* Output: - -``` bash - [2022-05-09 10:28:40,556] [ INFO] - vector score http client start - [2022-05-09 10:28:40,556] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav - [2022-05-09 10:28:40,556] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score - [2022-05-09 10:28:40,731] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} - [2022-05-09 10:28:40,731] [ INFO] - The vector: None - [2022-05-09 10:28:40,731] [ INFO] - Response time 0.175514 s. 
-``` + Output: + + ``` bash + [2022-05-25 12:33:24,527] [ INFO] - vector score http client start + [2022-05-25 12:33:24,527] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav + [2022-05-25 12:33:24,528] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score + [2022-05-25 12:33:24,695] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + [2022-05-25 12:33:24,696] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + [2022-05-25 12:33:24,696] [ INFO] - Response time 0.168271 s. + ``` * Python API -``` python -from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor - -vectorclient_executor = VectorClientExecutor() -res = vectorclient_executor( - input=None, - enroll_audio="85236145389.wav", - test_audio="123456789.wav", - server_ip="127.0.0.1", - port=8090, - task="score") -print(res) -``` + ``` python + from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor -* Output: + vectorclient_executor = VectorClientExecutor() + res = vectorclient_executor( + input=None, + enroll_audio="85236145389.wav", + test_audio="123456789.wav", + server_ip="127.0.0.1", + port=8090, + task="score") + print(res) + ``` -``` bash -[2022-05-09 10:34:54,769] [ INFO] - vector score http client start -[2022-05-09 10:34:54,771] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav -[2022-05-09 10:34:54,771] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score -[2022-05-09 10:34:55,026] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} -``` + Output: + ``` bash + [2022-05-25 12:30:14,143] [ INFO] - vector score http client start + [2022-05-25 12:30:14,143] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav + [2022-05-25 
12:30:14,143] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score + [2022-05-25 12:30:14,363] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + ``` ### 8. Punctuation prediction @@ -382,7 +382,7 @@ print(res) ```bash paddlespeech_client text --help ``` - 参数: + Arguments: - `server_ip`: server ip. Default: 127.0.0.1 - `port`: server port. Default: 8090 - `input`(required): Input text to get punctuation. diff --git a/demos/speech_server/README_cn.md b/demos/speech_server/README_cn.md index 51b6caa40..29629b7e8 100644 --- a/demos/speech_server/README_cn.md +++ b/demos/speech_server/README_cn.md @@ -3,7 +3,7 @@ # 语音服务 ## 介绍 -这个demo是一个启动离线语音服务和访问服务的实现。它可以通过使用`paddlespeech_server` 和 `paddlespeech_client`的单个命令或 python 的几行代码来实现。 +这个 demo 是一个启动离线语音服务和访问服务的实现。它可以通过使用`paddlespeech_server` 和 `paddlespeech_client`的单个命令或 python 的几行代码来实现。 ## 使用方法 @@ -24,7 +24,7 @@ ASR client 的输入是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 -可以下载此 ASR client的示例音频: +可以下载此 ASR client 的示例音频: ```bash wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav ``` @@ -99,7 +99,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ``` 参数: - - `server_ip`: 服务端ip地址,默认: 127.0.0.1。 + - `server_ip`: 服务端 ip 地址,默认: 127.0.0.1。 - `port`: 服务端口,默认: 8090。 - `input`(必须输入): 用于识别的音频文件。 - `sample_rate`: 音频采样率,默认值:16000。 @@ -198,10 +198,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ``` - ### 6. CLS 客户端使用方法 +### 6. 
CLS 客户端使用方法 - **注意:** 初次使用客户端时响应时间会略长 - - 命令行 (推荐使用) +**注意:** 初次使用客户端时响应时间会略长 + +- 命令行 (推荐使用) 若 `127.0.0.1` 不能访问,则需要使用实际服务 IP 地址 @@ -215,7 +216,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee paddlespeech_client cls --help ``` 参数: - - `server_ip`: 服务端ip地址,默认: 127.0.0.1。 + - `server_ip`: 服务端 ip 地址,默认: 127.0.0.1。 - `port`: 服务端口,默认: 8090。 - `input`(必须输入): 用于分类的音频文件。 - `topk`: 分类结果的topk。 @@ -261,48 +262,48 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee paddlespeech_client vector --task spk --server_ip 127.0.0.1 --port 8090 --input 85236145389.wav ``` -* 使用帮助: + 使用帮助: ``` bash paddlespeech_client vector --help ``` -* 参数: + 参数: * server_ip: 服务端ip地址,默认: 127.0.0.1。 * port: 服务端口,默认: 8090。 * input(必须输入): 用于识别的音频文件。 * task: vector 的任务,可选spk或者score。默认是 spk。 * enroll: 注册音频;。 * test: 测试音频。 -* 输出: - -``` bash - [2022-05-08 00:18:44,249] [ INFO] - vector http client start - [2022-05-08 00:18:44,250] [ INFO] - the input audio: 85236145389.wav - [2022-05-08 00:18:44,250] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector - [2022-05-08 00:18:44,250] [ INFO] - http://127.0.0.1:8590/paddlespeech/vector - [2022-05-08 00:18:44,406] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, 
-5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 5.142068862915039, 2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, 
-0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, -8.265326499938965, -9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 0.4657334089279175, 3.1326050758361816, 12.438895225524902, -1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} - [2022-05-08 00:18:44,406] [ INFO] - Response time 0.156481 s. 
-``` + 输出: + + ``` bash + [2022-05-25 12:25:36,165] [ INFO] - vector http client start + [2022-05-25 12:25:36,165] [ INFO] - the input audio: 85236145389.wav + [2022-05-25 12:25:36,165] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector + [2022-05-25 12:25:36,166] [ INFO] - http://127.0.0.1:8790/paddlespeech/vector + [2022-05-25 12:25:36,324] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, -3.3574001789093018, 10.13400650024414, -0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, 
-10.99522876739502, 0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 0.08044180274009705, -6.564010143280029, -2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, -9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 
1.9181575775146484, 6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}} + [2022-05-25 12:25:36,324] [ INFO] - Response time 0.159053 s. + ``` * Python API -``` python -from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor + ``` python + from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor -vectorclient_executor = VectorClientExecutor() -res = vectorclient_executor( - input="85236145389.wav", - server_ip="127.0.0.1", - port=8090, - task="spk") -print(res) -``` + vectorclient_executor = VectorClientExecutor() + res = vectorclient_executor( + input="85236145389.wav", + server_ip="127.0.0.1", + port=8090, + task="spk") + print(res) + ``` -* 输出: + 输出: -``` bash - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, -5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 
5.142068862915039, 2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, -0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, 
-8.265326499938965, -9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 0.4657334089279175, 3.1326050758361816, 12.438895225524902, -1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} -``` + ``` bash + {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, 
-3.3574001789093018, 10.13400650024414, -0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, -10.99522876739502, 0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 
0.08044180274009705, -6.564010143280029, -2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, -9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 1.9181575775146484, 6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}} + ``` #### 7.2 音频声纹打分 @@ -315,60 +316,62 @@ print(res) paddlespeech_client vector --task score --server_ip 127.0.0.1 --port 8090 --enroll 85236145389.wav --test 123456789.wav ``` -* 使用帮助: + 使用帮助: ``` bash paddlespeech_client vector --help ``` -* 参数: + 参数: * server_ip: 服务端ip地址,默认: 127.0.0.1。 * port: 服务端口,默认: 8090。 * input(必须输入): 用于识别的音频文件。 * task: vector 的任务,可选spk或者score。默认是 spk。 * enroll: 注册音频;。 * test: 测试音频。 -* 输出: - -``` bash - [2022-05-09 10:28:40,556] [ INFO] - vector score http client start - [2022-05-09 10:28:40,556] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav - [2022-05-09 10:28:40,556] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score - [2022-05-09 10:28:40,731] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} - [2022-05-09 10:28:40,731] [ INFO] - The vector: None - [2022-05-09 10:28:40,731] [ INFO] - Response time 0.175514 s. 
-``` + + 输出: + + ``` bash + [2022-05-25 12:33:24,527] [ INFO] - vector score http client start + [2022-05-25 12:33:24,527] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav + [2022-05-25 12:33:24,528] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score + [2022-05-25 12:33:24,695] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + [2022-05-25 12:33:24,696] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + [2022-05-25 12:33:24,696] [ INFO] - Response time 0.168271 s. + ``` * Python API -``` python -from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor - -vectorclient_executor = VectorClientExecutor() -res = vectorclient_executor( - input=None, - enroll_audio="85236145389.wav", - test_audio="123456789.wav", - server_ip="127.0.0.1", - port=8090, - task="score") -print(res) -``` + ``` python + from paddlespeech.server.bin.paddlespeech_client import VectorClientExecutor -* 输出: + vectorclient_executor = VectorClientExecutor() + res = vectorclient_executor( + input=None, + enroll_audio="85236145389.wav", + test_audio="123456789.wav", + server_ip="127.0.0.1", + port=8090, + task="score") + print(res) + ``` -``` bash -[2022-05-09 10:34:54,769] [ INFO] - vector score http client start -[2022-05-09 10:34:54,771] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav -[2022-05-09 10:34:54,771] [ INFO] - endpoint: http://127.0.0.1:8590/paddlespeech/vector/score -[2022-05-09 10:34:55,026] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} -``` + 输出: + + ``` bash + [2022-05-25 12:30:14,143] [ INFO] - vector score http client start + [2022-05-25 12:30:14,143] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav + [2022-05-25 
12:30:14,143] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score + [2022-05-25 12:30:14,363] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}} + ``` ### 8. 标点预测 **注意:** 初次使用客户端时响应时间会略长 - - 命令行 (推荐使用) +- 命令行 (推荐使用) 若 `127.0.0.1` 不能访问,则需要使用实际服务 IP 地址 @@ -411,17 +414,17 @@ print(res) ``` ## 服务支持的模型 -### ASR支持的模型 -通过 `paddlespeech_server stats --task asr` 获取ASR服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 +### ASR 支持的模型 +通过 `paddlespeech_server stats --task asr` 获取 ASR 服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 -### TTS支持的模型 -通过 `paddlespeech_server stats --task tts` 获取TTS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 +### TTS 支持的模型 +通过 `paddlespeech_server stats --task tts` 获取 TTS 服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 -### CLS支持的模型 -通过 `paddlespeech_server stats --task cls` 获取CLS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 +### CLS 支持的模型 +通过 `paddlespeech_server stats --task cls` 获取 CLS 服务支持的所有模型,其中静态模型可用于 paddle inference 推理。 -### Vector支持的模型 -通过 `paddlespeech_server stats --task vector` 获取Vector服务支持的所有模型。 +### Vector 支持的模型 +通过 `paddlespeech_server stats --task vector` 获取 Vector 服务支持的所有模型。 ### Text支持的模型 -通过 `paddlespeech_server stats --task text` 获取Text服务支持的所有模型。 +通过 `paddlespeech_server stats --task text` 获取 Text 服务支持的所有模型。 diff --git a/demos/speech_server/start_multi_progress_server.py b/demos/speech_server/start_multi_progress_server.py new file mode 100644 index 000000000..5e86befb7 --- /dev/null +++ b/demos/speech_server/start_multi_progress_server.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import warnings + +import uvicorn +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from paddlespeech.server.engine.engine_pool import init_engine_pool +from paddlespeech.server.restful.api import setup_router as setup_http_router +from paddlespeech.server.utils.config import get_config +from paddlespeech.server.ws.api import setup_router as setup_ws_router +warnings.filterwarnings("ignore") +import sys + +app = FastAPI( + title="PaddleSpeech Serving API", description="Api", version="0.0.1") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"]) + +# change yaml file here +config_file = "./conf/application.yaml" +config = get_config(config_file) + +# init engine +if not init_engine_pool(config): + print("Failed to init engine.") + sys.exit(-1) + +# get api_router +api_list = list(engine.split("_")[0] for engine in config.engine_list) +if config.protocol == "websocket": + api_router = setup_ws_router(api_list) +elif config.protocol == "http": + api_router = setup_http_router(api_list) +else: + raise Exception("unsupported protocol") + sys.exit(-1) + +# app needs to operate outside the main function +app.include_router(api_router) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--workers", type=int, help="workers of server", default=1) + args = parser.parse_args() + + uvicorn.run( + "start_multi_progress_server:app", + host=config.host, + port=config.port, + debug=True, + 
workers=args.workers) diff --git a/demos/speech_translation/README.md b/demos/speech_translation/README.md index f675a4eda..00a9c7932 100644 --- a/demos/speech_translation/README.md +++ b/demos/speech_translation/README.md @@ -47,7 +47,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - Python API ```python import paddle - from paddlespeech.cli import STExecutor + from paddlespeech.cli.st import STExecutor st_executor = STExecutor() text = st_executor( diff --git a/demos/speech_translation/README_cn.md b/demos/speech_translation/README_cn.md index bad9b392f..5119bf9f4 100644 --- a/demos/speech_translation/README_cn.md +++ b/demos/speech_translation/README_cn.md @@ -47,7 +47,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - Python API ```python import paddle - from paddlespeech.cli import STExecutor + from paddlespeech.cli.st import STExecutor st_executor = STExecutor() text = st_executor( diff --git a/demos/streaming_asr_server/README.md b/demos/streaming_asr_server/README.md index 4824da628..a770f58c3 100644 --- a/demos/streaming_asr_server/README.md +++ b/demos/streaming_asr_server/README.md @@ -33,6 +33,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav ```bash # in PaddleSpeech/demos/streaming_asr_server start the service paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml + # if you want to increase decoding speed, you can use the config file below, it will increase decoding speed and reduce accuracy + paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application_faster.yaml ``` Usage: diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md index 4ed15e17e..c771869e9 100644 --- a/demos/streaming_asr_server/README_cn.md +++ b/demos/streaming_asr_server/README_cn.md @@ -21,7 +21,7 @@ 下载好 `PaddleSpeech` 之后,进入到 `PaddleSpeech/demos/streaming_asr_server` 目录。 配置文件可参见该目录下 
`conf/ws_application.yaml` 和 `conf/ws_conformer_wenetspeech_application.yaml` 。 -目前服务集成的模型有: DeepSpeech2和 conformer模型,对应的配置文件如下: +目前服务集成的模型有: DeepSpeech2 和 conformer模型,对应的配置文件如下: * DeepSpeech: `conf/ws_application.yaml` * conformer: `conf/ws_conformer_wenetspeech_application.yaml` @@ -40,6 +40,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav ```bash # 在 PaddleSpeech/demos/streaming_asr_server 目录启动服务 paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml + # 你如果愿意为了增加解码的速度而牺牲一定的模型精度,你可以使用如下的脚本 + paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application_faster.yaml ``` 使用方法: diff --git a/demos/streaming_asr_server/conf/application.yaml b/demos/streaming_asr_server/conf/application.yaml index e9a89c19d..683d86f03 100644 --- a/demos/streaming_asr_server/conf/application.yaml +++ b/demos/streaming_asr_server/conf/application.yaml @@ -31,6 +31,8 @@ asr_online: force_yes: True device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" + continuous_decoding: True # enable continue decoding when endpoint detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml index 2affde073..01bb1e9c9 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8090 +port: 8091 # The task format in the engin_list is: _ # task choices = ['asr_online'] @@ -28,8 +28,12 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: -1 force_yes: True device: 'cpu' # cpu or gpu:id + decode_method: "attention_rescoring" + continuous_decoding: True # enable continue decoding when endpoint detected + 
am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml index e9a89c19d..d30bcd025 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml @@ -31,6 +31,8 @@ asr_online: force_yes: True device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" + continuous_decoding: True # enable continue decoding when endpoint detected + num_decoding_left_chunks: -1 am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application_faster.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application_faster.yaml new file mode 100644 index 000000000..ba413c802 --- /dev/null +++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application_faster.yaml @@ -0,0 +1,48 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# task choices = ['asr_online'] +# protocol = ['websocket'] (only one can be selected). +# websocket only support online engine type. 
+protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer_online_wenetspeech' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + device: 'cpu' # cpu or gpu:id + decode_method: "attention_rescoring" + continuous_decoding: True # enable continue decoding when endpoint detected + num_decoding_left_chunks: 16 + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 diff --git a/demos/streaming_asr_server/conf/ws_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml similarity index 98% rename from demos/streaming_asr_server/conf/ws_application.yaml rename to demos/streaming_asr_server/conf/ws_ds2_application.yaml index f2ea6330f..d19bd26dc 100644 --- a/demos/streaming_asr_server/conf/ws_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -28,6 +28,7 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: force_yes: True device: 'cpu' # cpu or gpu:id diff --git a/demos/streaming_asr_server/server.sh b/demos/streaming_asr_server/server.sh index 4266f8c64..f532546e7 100755 --- a/demos/streaming_asr_server/server.sh +++ b/demos/streaming_asr_server/server.sh @@ -4,5 +4,6 @@ export 
CUDA_VISIBLE_DEVICE=0,1,2,3 # nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 & paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log & -# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 & -paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log & \ No newline at end of file +# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_wenetspeech_application.yaml > streaming_asr.log 2>&1 & +paddlespeech_server start --config_file conf/ws_conformer_wenetspeech_application.yaml &> streaming_asr.log & + diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh index 4f43c6534..f3075454d 100755 --- a/demos/streaming_asr_server/test.sh +++ b/demos/streaming_asr_server/test.sh @@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa # read the wav and call streaming and punc service # If `127.0.0.1` is not accessible, you need to use the actual service IP address. 
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav -paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav \ No newline at end of file +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav + diff --git a/demos/streaming_asr_server/web/templates/index.html b/demos/streaming_asr_server/web/templates/index.html index 56c630808..768aebb8c 100644 --- a/demos/streaming_asr_server/web/templates/index.html +++ b/demos/streaming_asr_server/web/templates/index.html @@ -93,6 +93,7 @@ function parseResult(data) { var data = JSON.parse(data) + console.log('result json:', data) var result = data.result console.log(result) $("#resultPanel").html(result) diff --git a/demos/streaming_asr_server/websocket_client.py b/demos/streaming_asr_server/websocket_client.py index 8a4fe330a..8e1f19a58 100644 --- a/demos/streaming_asr_server/websocket_client.py +++ b/demos/streaming_asr_server/websocket_client.py @@ -13,9 +13,7 @@ # limitations under the License. #!/usr/bin/python # -*- coding: UTF-8 -*- - # script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' - import argparse import asyncio import codecs diff --git a/demos/streaming_tts_server/README.md b/demos/streaming_tts_server/README.md index 775cd9086..860d9a978 100644 --- a/demos/streaming_tts_server/README.md +++ b/demos/streaming_tts_server/README.md @@ -27,7 +27,7 @@ The configuration file can be found in `conf/tts_online_application.yaml`. - In streaming voc inference, one chunk of data is inferred at a time to achieve a streaming effect. Where `voc_block` indicates the number of valid frames in the chunk, and `voc_pad` indicates the number of frames added before and after the voc_block in a chunk. 
The existence of voc_pad is used to eliminate errors caused by streaming inference and avoid the influence of streaming inference on the quality of synthesized audio. - Both hifigan and mb_melgan support streaming voc inference. - When the voc model is mb_melgan, when voc_pad=14, the synthetic audio for streaming inference is consistent with the non-streaming synthetic audio; the minimum voc_pad can be set to 7, and the synthetic audio has no abnormal hearing. If the voc_pad is less than 7, the synthetic audio sounds abnormal. - - When the voc model is hifigan, when voc_pad=20, the streaming inference synthetic audio is consistent with the non-streaming synthetic audio; when voc_pad=14, the synthetic audio has no abnormal hearing. + - When the voc model is hifigan, when voc_pad=19, the streaming inference synthetic audio is consistent with the non-streaming synthetic audio; when voc_pad=14, the synthetic audio has no abnormal hearing. - Inference speed: mb_melgan > hifigan; Audio quality: mb_melgan < hifigan - **Note:** If the service can be started normally in the container, but the client access IP is unreachable, you can try to replace the `host` address in the configuration file with the local IP address. 
diff --git a/demos/streaming_tts_server/README_cn.md b/demos/streaming_tts_server/README_cn.md index 9c2cc50ec..254ec26a2 100644 --- a/demos/streaming_tts_server/README_cn.md +++ b/demos/streaming_tts_server/README_cn.md @@ -27,7 +27,7 @@ - 流式 voc 推理中,每次会对一个 chunk 的数据进行推理以达到流式的效果。其中 `voc_block` 表示chunk中的有效帧数,`voc_pad` 表示一个 chunk 中 voc_block 前后各加的帧数。voc_pad 的存在用于消除流式推理产生的误差,避免由流式推理对合成音频质量的影响。 - hifigan, mb_melgan 均支持流式 voc 推理 - 当 voc 模型为 mb_melgan,当 voc_pad=14 时,流式推理合成音频与非流式合成音频一致;voc_pad 最小可以设置为7,合成音频听感上没有异常,若 voc_pad 小于7,合成音频听感上存在异常。 - - 当 voc 模型为 hifigan,当 voc_pad=20 时,流式推理合成音频与非流式合成音频一致;当 voc_pad=14 时,合成音频听感上没有异常。 + - 当 voc 模型为 hifigan,当 voc_pad=19 时,流式推理合成音频与非流式合成音频一致;当 voc_pad=14 时,合成音频听感上没有异常。 - 推理速度:mb_melgan > hifigan; 音频质量:mb_melgan < hifigan - **注意:** 如果在容器里可正常启动服务,但客户端访问 ip 不可达,可尝试将配置文件中 `host` 地址换成本地 ip 地址。 diff --git a/demos/streaming_tts_server/conf/tts_online_application.yaml b/demos/streaming_tts_server/conf/tts_online_application.yaml index 964e85ef9..0460a5e16 100644 --- a/demos/streaming_tts_server/conf/tts_online_application.yaml +++ b/demos/streaming_tts_server/conf/tts_online_application.yaml @@ -47,7 +47,7 @@ tts_online: am_pad: 12 # voc_pad and voc_block voc model to streaming voc infer, # when voc model is mb_melgan_csmsc, voc_pad set 14, streaming synthetic audio is the same as non-streaming synthetic audio; The minimum value of pad can be set to 7, streaming synthetic audio sounds normal - # when voc model is hifigan_csmsc, voc_pad set 20, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal + # when voc model is hifigan_csmsc, voc_pad set 19, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal voc_block: 36 voc_pad: 14 @@ -95,7 +95,7 @@ tts_online-onnx: am_pad: 12 # voc_pad and voc_block voc model to streaming voc infer, # when voc model is mb_melgan_csmsc_onnx, voc_pad set 14, streaming 
synthetic audio is the same as non-streaming synthetic audio; The minimum value of pad can be set to 7, streaming synthetic audio sounds normal - # when voc model is hifigan_csmsc_onnx, voc_pad set 20, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal + # when voc model is hifigan_csmsc_onnx, voc_pad set 19, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal voc_block: 36 voc_pad: 14 # voc_upsample should be same as n_shift on voc config. diff --git a/demos/text_to_speech/README.md b/demos/text_to_speech/README.md index 2df72a82d..389847a12 100644 --- a/demos/text_to_speech/README.md +++ b/demos/text_to_speech/README.md @@ -77,7 +77,7 @@ The input of this demo should be a text of the specific language that can be pas - Python API ```python import paddle - from paddlespeech.cli import TTSExecutor + from paddlespeech.cli.tts import TTSExecutor tts_executor = TTSExecutor() wav_file = tts_executor( diff --git a/demos/text_to_speech/README_cn.md b/demos/text_to_speech/README_cn.md index 7e02b9624..f967d3d4d 100644 --- a/demos/text_to_speech/README_cn.md +++ b/demos/text_to_speech/README_cn.md @@ -80,7 +80,7 @@ - Python API ```python import paddle - from paddlespeech.cli import TTSExecutor + from paddlespeech.cli.tts import TTSExecutor tts_executor = TTSExecutor() wav_file = tts_executor( diff --git a/docker/ubuntu18-cpu/Dockerfile b/docker/ubuntu18-cpu/Dockerfile new file mode 100644 index 000000000..d14c01858 --- /dev/null +++ b/docker/ubuntu18-cpu/Dockerfile @@ -0,0 +1,15 @@ +FROM registry.baidubce.com/paddlepaddle/paddle:2.2.2 +LABEL maintainer="paddlesl@baidu.com" + +RUN git clone --depth 1 https://github.com/PaddlePaddle/PaddleSpeech.git /home/PaddleSpeech +RUN pip3 uninstall mccabe -y ; exit 0; +RUN pip3 install multiprocess==0.70.12 importlib-metadata==4.2.0 dill==0.3.4 + +RUN cd /home/PaddleSpeech/audio +RUN 
python setup.py bdist_wheel + +RUN cd /home/PaddleSpeech +RUN python setup.py bdist_wheel +RUN pip install audio/dist/*.whl dist/*.whl + +WORKDIR /home/PaddleSpeech/ diff --git a/docs/paddlespeech.pdf b/docs/paddlespeech.pdf new file mode 100644 index 000000000..a1c498ad2 Binary files /dev/null and b/docs/paddlespeech.pdf differ diff --git a/docs/source/asr/PPASR_cn.md b/docs/source/asr/PPASR_cn.md index 82b1c1d37..2e3f1cd97 100644 --- a/docs/source/asr/PPASR_cn.md +++ b/docs/source/asr/PPASR_cn.md @@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle ## 4. 快速开始 关于如果使用 PP-ASR,可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单**、**中等**、**困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。 - - diff --git a/audio/docs/source/_static/custom.css b/docs/source/audio/_static/custom.css similarity index 100% rename from audio/docs/source/_static/custom.css rename to docs/source/audio/_static/custom.css diff --git a/audio/docs/source/_templates/module.rst_t b/docs/source/audio/_templates/module.rst_t similarity index 100% rename from audio/docs/source/_templates/module.rst_t rename to docs/source/audio/_templates/module.rst_t diff --git a/audio/docs/source/_templates/package.rst_t b/docs/source/audio/_templates/package.rst_t similarity index 100% rename from audio/docs/source/_templates/package.rst_t rename to docs/source/audio/_templates/package.rst_t diff --git a/audio/docs/source/_templates/toc.rst_t b/docs/source/audio/_templates/toc.rst_t similarity index 100% rename from audio/docs/source/_templates/toc.rst_t rename to docs/source/audio/_templates/toc.rst_t diff --git a/audio/docs/source/conf.py b/docs/source/audio/conf.py similarity index 100% rename from audio/docs/source/conf.py rename to docs/source/audio/conf.py diff --git a/audio/docs/source/index.rst b/docs/source/audio/index.rst similarity index 100% rename from audio/docs/source/index.rst rename to 
docs/source/audio/index.rst diff --git a/docs/source/cls/custom_dataset.md b/docs/source/cls/custom_dataset.md index aaf5943c5..e39dcf12d 100644 --- a/docs/source/cls/custom_dataset.md +++ b/docs/source/cls/custom_dataset.md @@ -1,8 +1,8 @@ # Customize Dataset for Audio Classification -Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech` and `paddleaudio`. +Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech`. -A base class of classification dataset is `paddleaudio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`. +A base class of classification dataset is `paddlespeech.audio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`. Assuming you have some wave files that stored in your own directory. You should prepare a meta file with the information of filepaths and labels. For example the absolute path of it is `/PATH/TO/META_FILE.txt`: ``` @@ -14,7 +14,7 @@ Assuming you have some wave files that stored in your own directory. 
You should Here is an example to build your custom dataset in `custom_dataset.py`: ```python -from paddleaudio.datasets.dataset import AudioClassificationDataset +from paddlespeech.audio.datasets.dataset import AudioClassificationDataset class CustomDataset(AudioClassificationDataset): meta_file = '/PATH/TO/META_FILE.txt' @@ -48,7 +48,7 @@ class CustomDataset(AudioClassificationDataset): Then you can build dataset and data loader from `CustomDataset`: ```python import paddle -from paddleaudio.features import LogMelSpectrogram +from paddlespeech.audio.features import LogMelSpectrogram from custom_dataset import CustomDataset diff --git a/docs/source/install.md b/docs/source/install.md index 43cc784cc..e3ea74b27 100644 --- a/docs/source/install.md +++ b/docs/source/install.md @@ -4,7 +4,7 @@ There are 3 ways to use `PaddleSpeech`. According to the degree of difficulty, t | Way | Function | Support| |:---- |:----------------------------------------------------------- |:----| -| Easy | (1) Use command-line functions of PaddleSpeech.
(2) Experience PaddleSpeech on Ai Studio. | Linux, Mac(not support M1 chip),Windows | +| Easy | (1) Use command-line functions of PaddleSpeech.
(2) Experience PaddleSpeech on Ai Studio. | Linux, Mac(not support M1 chip),Windows ( For more information about installation, see [#1195](https://github.com/PaddlePaddle/PaddleSpeech/discussions/1195)) | | Medium | Support major functions ,such as using the` ready-made `examples and using PaddleSpeech to train your model. | Linux | | Hard | Support full function of Paddlespeech, including using join ctc decoder with kaldi, training n-gram language model, Montreal-Forced-Aligner, and so on. And you are more able to be a developer! | Ubuntu | diff --git a/docs/source/install_cn.md b/docs/source/install_cn.md index 55fef93d5..5a967f404 100644 --- a/docs/source/install_cn.md +++ b/docs/source/install_cn.md @@ -3,7 +3,7 @@ `PaddleSpeech` 有三种安装方法。根据安装的难易程度,这三种方法可以分为 **简单**, **中等** 和 **困难**. | 方式 | 功能 | 支持系统 | | :--- | :----------------------------------------------------------- | :------------------ | -| 简单 | (1) 使用 PaddleSpeech 的命令行功能.
(2) 在 Aistudio上体验 PaddleSpeech. | Linux, Mac(不支持M1芯片),Windows | +| 简单 | (1) 使用 PaddleSpeech 的命令行功能.
(2) 在 Aistudio上体验 PaddleSpeech. | Linux, Mac(不支持M1芯片),Windows (安装详情查看[#1195](https://github.com/PaddlePaddle/PaddleSpeech/discussions/1195)) | | 中等 | 支持 PaddleSpeech 主要功能,比如使用已有 examples 中的模型和使用 PaddleSpeech 来训练自己的模型. | Linux | | 困难 | 支持 PaddleSpeech 的各项功能,包含结合kaldi使用 join ctc decoder 方式解码,训练语言模型,使用强制对齐等。并且你更能成为一名开发者! | Ubuntu | ## 先决条件 diff --git a/docs/source/released_model.md b/docs/source/released_model.md index 74435ae1a..5afd3c478 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -6,15 +6,15 @@ ### Speech Recognition Model Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link :-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: -[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM)
0.2417 (test\_meeting, w/o LM)
0.053 (aishell, w/ LM) |-| 10000 h |- +[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM)
0.2417 (test\_meeting, w/o LM)
0.053 (aishell, w/ LM) |-| 10000 h |- [Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) -[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) +[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) [Transformer Aishell ASR1 
Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) -[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz)| Librispeech Dataset | Char-based | 518 MB | 2 Conv + 3 bidirectional LSTM layers| - |0.0725| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) -[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0337 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) +[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) +[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) [Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech 
ASR1](../../examples/librispeech/asr1) [Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2) @@ -82,17 +82,9 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https Model Type | Dataset| Example Link | Pretrained Models | Static Models :-------------:| :------------:| :-----: | :-----: | :-----: -PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz) | - +ECAPA-TDNN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz) | - ## Punctuation Restoration Models Model Type | Dataset| Example Link | Pretrained Models :-------------:| :------------:| :-----: | :-----: Ernie Linear | IWLST2012_zh |[iwslt2012_punc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/iwslt2012/punc0)|[ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip) - -## Speech Recognition Model from paddle 1.8 - -| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech | -| :-----:| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | -| [Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz) | Aishell Dataset | Char-based | 234 MB | 2 Conv + 3 bidirectional GRU layers | 0.0804 | — | 151 h | -| [Ds2 Offline Librispeech 
model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h | -| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers |— | 0.0541 | 8628 h| diff --git a/examples/aishell/asr0/RESULTS.md b/examples/aishell/asr0/RESULTS.md index 131b66286..299445b77 100644 --- a/examples/aishell/asr0/RESULTS.md +++ b/examples/aishell/asr0/RESULTS.md @@ -12,7 +12,8 @@ ## Deepspeech2 Non-Streaming | Model | Number of Params | Release | Config | Test set | Valid Loss | CER | -| --- | --- | --- | --- | --- | --- | --- | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 122.3M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipline and spec aug + fbank161 | test | 5.780756044387817 | 0.055400 | | DeepSpeech2 | 58.4M | v2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 | | DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | diff --git a/examples/aishell/asr0/conf/augmentation.json b/examples/aishell/asr0/conf/augmentation.json deleted file mode 100644 index 31c481c8d..000000000 --- a/examples/aishell/asr0/conf/augmentation.json +++ /dev/null @@ -1,36 +0,0 @@ -[ - { - "type": "speed", - "params": { - "min_speed_rate": 0.9, - "max_speed_rate": 1.1, - "num_rates": 3 - }, - "prob": 0.0 - }, - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - }, - { - "type": "specaug", - "params": { - "W": 0, - "warp_mode": "PIL", - "F": 10, - "n_freq_masks": 2, - "T": 50, - "n_time_masks": 2, - "p": 1.0, - "adaptive_number_ratio": 0, - "adaptive_size_ratio": 0, - "max_n_time_masks": 20, - 
"replace_with_zero": true - }, - "prob": 1.0 - } -] diff --git a/examples/aishell/asr0/conf/deepspeech2.yaml b/examples/aishell/asr0/conf/deepspeech2.yaml index fb6998647..913354f5d 100644 --- a/examples/aishell/asr0/conf/deepspeech2.yaml +++ b/examples/aishell/asr0/conf/deepspeech2.yaml @@ -15,50 +15,53 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.json -unit_type: char vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 +num_rnn_layers: 5 rnn_layer_size: 1024 -use_gru: True -share_rnn_weights: False +rnn_direction: bidirect # [forward, bidirect] +num_fc_layers: 0 +fc_layers_size_list: -1, +use_gru: False blank_id: 0 -ctc_grad_norm_type: instance - + + ########################################### # Training # 
########################################### -n_epoch: 80 +n_epoch: 50 accum_grad: 1 -lr: 2.0e-3 -lr_decay: 0.83 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 3.0 -log_interval: 100 +dist_sampler: False +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 + + diff --git a/examples/aishell/asr0/conf/deepspeech2_online.yaml b/examples/aishell/asr0/conf/deepspeech2_online.yaml index ef01ac595..a53e19f37 100644 --- a/examples/aishell/asr0/conf/deepspeech2_online.yaml +++ b/examples/aishell/asr0/conf/deepspeech2_online.yaml @@ -15,28 +15,26 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.json -unit_type: char vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear #linear, mfcc, fbank +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # @@ -54,12 +52,13 @@ blank_id: 0 ########################################### # Training # 
########################################### -n_epoch: 65 +n_epoch: 30 accum_grad: 1 lr: 5.0e-4 lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 3.0 +dist_sampler: False log_interval: 100 checkpoint: kbest_n: 50 diff --git a/examples/aishell/asr0/conf/preprocess.yaml b/examples/aishell/asr0/conf/preprocess.yaml new file mode 100644 index 000000000..3f526e0ad --- /dev/null +++ b/examples/aishell/asr0/conf/preprocess.yaml @@ -0,0 +1,25 @@ +process: + # extract kaldi fbank from PCM + - type: fbank_kaldi + fs: 16000 + n_mels: 161 + n_shift: 160 + win_length: 400 + dither: 0.1 + - type: cmvn_json + cmvn_path: data/mean_std.json + # these three processes are a.k.a. SpecAugument + - type: time_warp + max_time_warp: 5 + inplace: true + mode: PIL + - type: freq_mask + F: 30 + n_mask: 2 + inplace: true + replace_with_zero: false + - type: time_mask + T: 40 + n_mask: 2 + inplace: true + replace_with_zero: false diff --git a/examples/aishell/asr0/conf/tuning/decode.yaml b/examples/aishell/asr0/conf/tuning/decode.yaml index 5778e6565..7dbc6fa82 100644 --- a/examples/aishell/asr0/conf/tuning/decode.yaml +++ b/examples/aishell/asr0/conf/tuning/decode.yaml @@ -2,9 +2,9 @@ decode_batch_size: 128 error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm -alpha: 1.9 -beta: 5.0 -beam_size: 300 +alpha: 2.2 +beta: 4.3 +beam_size: 500 cutoff_prob: 0.99 cutoff_top_n: 40 num_proc_bsearch: 10 diff --git a/examples/aishell/asr0/local/data.sh b/examples/aishell/asr0/local/data.sh index ec692eba6..8722c1ca3 100755 --- a/examples/aishell/asr0/local/data.sh +++ b/examples/aishell/asr0/local/data.sh @@ -33,12 +33,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then num_workers=$(nproc) python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ - --spectrum_type="linear" \ + --spectrum_type="fbank" \ + --feat_dim=161 \ --delta_delta=false \ --stride_ms=10 \ - --window_ms=20 \ + --window_ms=25 \ 
--sample_rate=16000 \ - --use_dB_normalization=True \ + --use_dB_normalization=False \ --num_samples=2000 \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" diff --git a/examples/aishell/asr0/local/export.sh b/examples/aishell/asr0/local/export.sh index 426a72fe5..ce7e6d642 100755 --- a/examples/aishell/asr0/local/export.sh +++ b/examples/aishell/asr0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" exit -1 fi @@ -11,14 +11,12 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -model_type=$4 python3 -u ${BIN_DIR}/export.py \ --ngpu ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} \ ---model_type ${model_type} +--export_path ${jit_model_export_path} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/aishell/asr0/local/test.sh b/examples/aishell/asr0/local/test.sh index 363dbf0ab..778c7142e 100755 --- a/examples/aishell/asr0/local/test.sh +++ b/examples/aishell/asr0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi @@ -13,7 +13,6 @@ echo "using $ngpu gpus..." 
config_path=$1 decode_config_path=$2 ckpt_prefix=$3 -model_type=$4 # download language model bash local/download_lm_ch.sh @@ -23,7 +22,7 @@ fi if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # format the reference test file - python utils/format_rsl.py \ + python3 utils/format_rsl.py \ --origin_ref data/manifest.test.raw \ --trans_ref data/manifest.test.text @@ -32,8 +31,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --config ${config_path} \ --decode_cfg ${decode_config_path} \ --result_file ${ckpt_prefix}.rsl \ - --checkpoint_path ${ckpt_prefix} \ - --model_type ${model_type} + --checkpoint_path ${ckpt_prefix} if [ $? -ne 0 ]; then echo "Failed in evaluation!" @@ -41,25 +39,25 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi # format the hyp file - python utils/format_rsl.py \ + python3 utils/format_rsl.py \ --origin_hyp ${ckpt_prefix}.rsl \ --trans_hyp ${ckpt_prefix}.rsl.text - python utils/compute-wer.py --char=1 --v=1 \ - data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error + python3 utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error fi if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then - python utils/format_rsl.py \ + python3 utils/format_rsl.py \ --origin_ref data/manifest.test.raw \ --trans_ref_sclite data/manifest.test.text.sclite - python utils/format_rsl.py \ - --origin_hyp ${ckpt_prefix}.rsl \ - --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite - mkdir -p ${ckpt_prefix}_sclite - sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII + mkdir -p ${ckpt_prefix}_sclite + sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII fi exit 0 diff --git 
a/examples/aishell/asr0/local/test_export.sh b/examples/aishell/asr0/local/test_export.sh index 7a4b87f8c..a46a0d876 100755 --- a/examples/aishell/asr0/local/test_export.sh +++ b/examples/aishell/asr0/local/test_export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi @@ -11,7 +11,6 @@ echo "using $ngpu gpus..." config_path=$1 decode_config_path=$2 jit_model_export_path=$3 -model_type=$4 # download language model bash local/download_lm_ch.sh > /dev/null 2>&1 @@ -24,8 +23,7 @@ python3 -u ${BIN_DIR}/test_export.py \ --config ${config_path} \ --decode_cfg ${decode_config_path} \ --result_file ${jit_model_export_path}.rsl \ ---export_path ${jit_model_export_path} \ ---model_type ${model_type} +--export_path ${jit_model_export_path} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/aishell/asr0/local/test_wav.sh b/examples/aishell/asr0/local/test_wav.sh index 62b005a6a..a228dda5a 100755 --- a/examples/aishell/asr0/local/test_wav.sh +++ b/examples/aishell/asr0/local/test_wav.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 5 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type audio_file" +if [ $# != 4 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file" exit -1 fi @@ -11,8 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 decode_config_path=$2 ckpt_prefix=$3 -model_type=$4 -audio_file=$5 +audio_file=$4 mkdir -p data wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/ @@ -37,7 +36,6 @@ python3 -u ${BIN_DIR}/test_wav.py \ --decode_cfg ${decode_config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} \ --audio_file ${audio_file} if [ $? 
-ne 0 ]; then diff --git a/examples/aishell/asr0/local/train.sh b/examples/aishell/asr0/local/train.sh index 102c051c1..256b30d22 100755 --- a/examples/aishell/asr0/local/train.sh +++ b/examples/aishell/asr0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,7 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -model_type=$3 +ips=$3 + +if [ ! $ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -25,14 +31,12 @@ python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --seed ${seed} fi diff --git a/examples/aishell/asr0/run.sh b/examples/aishell/asr0/run.sh index 114af5a97..530c013ac 100755 --- a/examples/aishell/asr0/run.sh +++ b/examples/aishell/asr0/run.sh @@ -6,9 +6,9 @@ gpus=0,1,2,3 stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml #conf/deepspeech2.yaml or conf/deepspeech2_online.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml -avg_num=1 -model_type=offline # offline or online +avg_num=10 audio_file=data/demo_01_03.wav source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -25,7 +25,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ 
${stop_stage} -ge 2 ]; then @@ -35,21 +35,21 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}|| exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} + CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # test export ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit|| exit -1 fi # Optionally, you can add LM and test it with runtime. if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then # test a single .wav file - CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} ${audio_file} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 fi diff --git a/examples/aishell/asr1/local/train.sh b/examples/aishell/asr1/local/train.sh index 5617f7efe..f514de303 100755 --- a/examples/aishell/asr1/local/train.sh +++ b/examples/aishell/asr1/local/train.sh @@ -17,13 +17,21 @@ if [ ${seed} != 0 ]; then echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." 
fi -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi config_path=$1 ckpt_name=$2 +ips=$3 + +if [ ! $ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi +echo ${ips_config} mkdir -p exp @@ -37,7 +45,7 @@ python3 -u ${BIN_DIR}/train.py \ --benchmark-batch-size ${benchmark_batch_size} \ --benchmark-max-step ${benchmark_max_step} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --seed ${seed} \ --config ${config_path} \ diff --git a/examples/aishell/asr1/run.sh b/examples/aishell/asr1/run.sh index cb781b208..bd4f50e3f 100644 --- a/examples/aishell/asr1/run.sh +++ b/examples/aishell/asr1/run.sh @@ -6,6 +6,7 @@ gpus=0,1,2,3 stage=0 stop_stage=50 conf_path=conf/conformer.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=30 audio_file=data/demo_01_03.wav @@ -23,7 +24,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/aishell3/tts3/README.md b/examples/aishell3/tts3/README.md index d02ad1b63..31c99898c 100644 --- a/examples/aishell3/tts3/README.md +++ b/examples/aishell3/tts3/README.md @@ -6,15 +6,8 @@ AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpu We use AISHELL-3 to train a multi-speaker fastspeech2 model here. ## Dataset ### Download and Extract -Download AISHELL-3. -```bash -wget https://www.openslr.org/resources/93/data_aishell3.tgz -``` -Extract AISHELL-3. 
-```bash -mkdir data_aishell3 -tar zxvf data_aishell3.tgz -C data_aishell3 -``` +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. + ### Get MFA Result and Extract We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. @@ -120,12 +113,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -134,11 +127,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am 
{speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -150,10 +142,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. 
@@ -169,12 +161,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -184,11 +176,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -199,10 +190,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. 
--spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -215,9 +206,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. 
diff --git a/examples/aishell3/vc0/README.md b/examples/aishell3/vc0/README.md index 925663ab1..d64f961ad 100644 --- a/examples/aishell3/vc0/README.md +++ b/examples/aishell3/vc0/README.md @@ -6,15 +6,8 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ## Dataset ### Download and Extract -Download AISHELL-3. -```bash -wget https://www.openslr.org/resources/93/data_aishell3.tgz -``` -Extract AISHELL-3. -```bash -mkdir data_aishell3 -tar zxvf data_aishell3.tgz -C data_aishell3 -``` +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. + ### Get MFA Result and Extract We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2, the durations of MFA are not needed here. You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md index 8ab0f9c8c..aab525103 100644 --- a/examples/aishell3/vc1/README.md +++ b/examples/aishell3/vc1/README.md @@ -6,15 +6,8 @@ This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2 ## Dataset ### Download and Extract -Download AISHELL-3. -```bash -wget https://www.openslr.org/resources/93/data_aishell3.tgz -``` -Extract AISHELL-3. -```bash -mkdir data_aishell3 -tar zxvf data_aishell3.tgz -C data_aishell3 -``` +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. 
+ ### Get MFA Result and Extract We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. diff --git a/examples/aishell3/voc1/README.md b/examples/aishell3/voc1/README.md index eb30e7c40..a3daf3dfd 100644 --- a/examples/aishell3/voc1/README.md +++ b/examples/aishell3/voc1/README.md @@ -4,15 +4,8 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems. ## Dataset ### Download and Extract -Download AISHELL-3. -```bash -wget https://www.openslr.org/resources/93/data_aishell3.tgz -``` -Extract AISHELL-3. -```bash -mkdir data_aishell3 -tar zxvf data_aishell3.tgz -C data_aishell3 -``` +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. + ### Get MFA Result and Extract We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. @@ -75,7 +68,7 @@ Train a ParallelWaveGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. 
+ --config CONFIG ParallelWaveGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md index c957c4a3a..c3e3197d6 100644 --- a/examples/aishell3/voc5/README.md +++ b/examples/aishell3/voc5/README.md @@ -4,15 +4,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010. AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems. ## Dataset ### Download and Extract -Download AISHELL-3. -```bash -wget https://www.openslr.org/resources/93/data_aishell3.tgz -``` -Extract AISHELL-3. -```bash -mkdir data_aishell3 -tar zxvf data_aishell3.tgz -C data_aishell3 -``` +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. ### Get MFA Result and Extract We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. @@ -67,15 +59,13 @@ Here's the complete help message. ```text usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] - [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] - [--run-benchmark RUN_BENCHMARK] - [--profiler_options PROFILER_OPTIONS] + [--ngpu NGPU] -Train a ParallelWaveGAN model. +Train a HiFiGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. 
+ --config CONFIG HiFiGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA @@ -83,19 +73,6 @@ optional arguments: --output-dir OUTPUT_DIR output dir. --ngpu NGPU if ngpu == 0, use cpu. - -benchmark: - arguments related to benchmark. - - --batch-size BATCH_SIZE - batch size. - --max-iter MAX_ITER train max steps. - --run-benchmark RUN_BENCHMARK - runing benchmark or not, if True, use the --batch-size - and --max-iter. - --profiler_options PROFILER_OPTIONS - The option of profiler, which should be in format - "key1=value1;key2=value2;key3=value3". ``` 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. diff --git a/examples/ami/sd0/README.md b/examples/ami/sd0/README.md index e9ecc2854..30f7a438d 100644 --- a/examples/ami/sd0/README.md +++ b/examples/ami/sd0/README.md @@ -26,4 +26,7 @@ Use the following command to run diarization on AMI corpus. ./run.sh --data_folder ./amicorpus --manual_annot_folder ./ami_public_manual_1.6.2 ``` -## Results (DER) coming soon! :) +## Best performance in terms of Diarization Error Rate (DER). + | System | Mic. |Orcl. (Dev)|Orcl. (Eval)| Est. (Dev) |Est. (Eval)| + | --------|-------- | ---------|----------- | --------|-----------| + | ECAPA-TDNN + SC | HeadsetMix| 1.54 % | 3.07 %| 1.56 %| 3.28 % | diff --git a/examples/callcenter/asr1/local/train.sh b/examples/callcenter/asr1/local/train.sh index 03b4588e3..41da89e22 100755 --- a/examples/callcenter/asr1/local/train.sh +++ b/examples/callcenter/asr1/local/train.sh @@ -1,7 +1,7 @@ #! /usr/bin/env bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] || [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,6 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +ips=$3 + +if [ !
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi echo "using ${device}..." @@ -28,7 +35,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ diff --git a/examples/callcenter/asr1/run.sh b/examples/callcenter/asr1/run.sh index 0c7ffc1e7..7e3b912ab 100644 --- a/examples/callcenter/asr1/run.sh +++ b/examples/callcenter/asr1/run.sh @@ -6,6 +6,7 @@ gpus=0,1,2,3 stage=0 stop_stage=50 conf_path=conf/conformer.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=20 @@ -22,7 +23,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/csmsc/tts0/README.md b/examples/csmsc/tts0/README.md index 01376bd61..bc7769d15 100644 --- a/examples/csmsc/tts0/README.md +++ b/examples/csmsc/tts0/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ## Dataset ### Download and Extract -Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/source). +Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2, the durations of MFA are not needed here.
@@ -103,12 +103,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -117,11 +117,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} + --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -133,10 +132,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. 
- --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -152,12 +151,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -167,11 +166,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose 
acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -182,10 +180,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. --spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -198,9 +196,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. 
`--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md index 081d85848..f45561719 100644 --- a/examples/csmsc/tts2/README.md +++ b/examples/csmsc/tts2/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [SpeedySpeech](http://arxiv.org/abs/2 ## Dataset ### Download and Extract -Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/source). +Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for SPEEDYSPEECH. @@ -109,12 +109,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -123,11 +123,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this
help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -139,10 +138,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. 
@@ -158,12 +157,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -173,11 +172,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -188,10 +186,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. 
--spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -204,9 +202,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` and `--tones_dict` are arguments for acoustic model, which correspond to the 5 files in the speedyspeech pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` and `--tones_dict` are arguments for acoustic model, which correspond to the 5 files in the speedyspeech pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. 
diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index c734199b4..371034e77 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -4,7 +4,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ## Dataset ### Download and Extract -Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/source). +Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2. @@ -111,12 +111,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -125,11 +125,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am
{speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -141,10 +140,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. 
@@ -160,12 +159,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -175,11 +174,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -190,10 +188,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. 
--spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -204,11 +202,12 @@ optional arguments: --text TEXT text to synthesize, a 'utt_id sentence' pair per line. --output_dir OUTPUT_DIR output dir. + ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. 
diff --git a/examples/csmsc/tts3/README_cn.md b/examples/csmsc/tts3/README_cn.md index 25931ecb1..1829b7706 100644 --- a/examples/csmsc/tts3/README_cn.md +++ b/examples/csmsc/tts3/README_cn.md @@ -5,7 +5,7 @@ ## 数据集 ### 下载并解压 -从 [官方网站](https://test.data-baker.com/data/index/source) 下载数据集 +从 [官方网站](https://test.data-baker.com/data/index/TNtts/) 下载数据集 ### 获取MFA结果并解压 我们使用 [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) 去获得 fastspeech2 的音素持续时间。 @@ -117,12 +117,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -131,11 +131,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. 
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -147,10 +146,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -167,12 +166,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -182,11 +181,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am 
{speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -197,10 +195,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. --spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -213,9 +211,9 @@ optional arguments: output dir. ``` 1. `--am` 声学模型格式是否符合 {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` 和 `--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。 +2. `--am_config`, `--am_ckpt`, `--am_stat` 和 `--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。 3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。 +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。 5. `--lang` 对应模型的语言可以是 `zh` 或 `en` 。 6. `--test_metadata` 应为 `dump` 文件夹中 `test` 下的规范化元数据文件、 7. 
`--text` 是文本文件，其中包含要合成的句子。 diff --git a/examples/csmsc/tts3/conf/default.yaml b/examples/csmsc/tts3/conf/default.yaml index 2c2a1ea10..08b6f75ba 100644 --- a/examples/csmsc/tts3/conf/default.yaml +++ b/examples/csmsc/tts3/conf/default.yaml @@ -86,8 +86,8 @@ updater: # OPTIMIZER SETTING # ########################################################### optimizer: - optim: adam # optimizer type - learning_rate: 0.001 # learning rate + optim: adam # optimizer type + learning_rate: 0.001 # learning rate ########################################################### # TRAINING SETTING # diff --git a/examples/csmsc/vits/README.md b/examples/csmsc/vits/README.md new file mode 100644 index 000000000..0c16840a0 --- /dev/null +++ b/examples/csmsc/vits/README.md @@ -0,0 +1,146 @@ +# VITS with CSMSC +This example contains code used to train a [VITS](https://arxiv.org/abs/2106.06103) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). + +## Dataset +### Download and Extract +Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source). + +### Get MFA Result and Extract +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for VITS, the durations of MFA are not needed here. +You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. + - synthesize waveform from a text file.
+ +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│   ├── norm +│   └── raw +├── phone_id_map.txt +├── speaker_id_map.txt +├── test +│   ├── norm +│   └── raw +└── train + ├── feats_stats.npy + ├── norm + └── raw +``` +The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains wave and linear spectrogram of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/feats_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, feats, feats_lengths, the path of linear spectrogram features, the path of raw waves, speaker, and the id of each utterance. + +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] [--phones-dict PHONES_DICT] + +Train a VITS model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. 
+ --phones-dict PHONES_DICT + phone vocabulary file. +``` +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. +5. `--phones-dict` is the path of the phone vocabulary file. + +### Synthesizing + +`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. + +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--config CONFIG] [--ckpt CKPT] + [--phones_dict PHONES_DICT] [--ngpu NGPU] + [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] + +Synthesize with VITS + +optional arguments: + -h, --help show this help message and exit + --config CONFIG Config of VITS. + --ckpt CKPT Checkpoint file of VITS. + --phones_dict PHONES_DICT + phone vocabulary file. + --ngpu NGPU if ngpu == 0, use cpu. + --test_metadata TEST_METADATA + test metadata. + --output_dir OUTPUT_DIR + output dir. +``` +`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveform from text file. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize_e2e.py [-h] [--config CONFIG] [--ckpt CKPT] + [--phones_dict PHONES_DICT] [--lang LANG] + [--inference_dir INFERENCE_DIR] [--ngpu NGPU] + [--text TEXT] [--output_dir OUTPUT_DIR] + +Synthesize with VITS + +optional arguments: + -h, --help show this help message and exit + --config CONFIG Config of VITS. + --ckpt CKPT Checkpoint file of VITS. 
+ --phones_dict PHONES_DICT + phone vocabulary file. + --lang LANG Choose model language. zh or en + --inference_dir INFERENCE_DIR + dir to save inference models + --ngpu NGPU if ngpu == 0, use cpu. + --text TEXT text to synthesize, a 'utt_id sentence' pair per line. + --output_dir OUTPUT_DIR + output dir. +``` +1. `--config`, `--ckpt`, and `--phones_dict` are arguments for acoustic model, which correspond to the 3 files in the VITS pretrained model. +2. `--lang` is the model language, which can be `zh` or `en`. +3. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. +4. `--text` is the text file, which contains sentences to synthesize. +5. `--output_dir` is the directory to save synthesized audio files. +6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +## Pretrained Model diff --git a/examples/csmsc/vits/conf/default.yaml b/examples/csmsc/vits/conf/default.yaml new file mode 100644 index 000000000..47af780dc --- /dev/null +++ b/examples/csmsc/vits/conf/default.yaml @@ -0,0 +1,183 @@ +# This configuration tested on 4 GPUs (V100) with 32GB GPU +# memory. It takes around 2 weeks to finish the training +# but 100k iters model should generate reasonable results. +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +fs: 22050 # sr +n_fft: 1024 # FFT size (samples). +n_shift: 256 # Hop size (samples). 12.5ms +win_length: null # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. 
+ + +########################################################## +# TTS MODEL SETTING # +########################################################## +model: + # generator related + generator_type: vits_generator + generator_params: + hidden_channels: 192 + spks: -1 + global_channels: -1 + segment_size: 32 + text_encoder_attention_heads: 2 + text_encoder_ffn_expand: 4 + text_encoder_blocks: 6 + text_encoder_positionwise_layer_type: "conv1d" + text_encoder_positionwise_conv_kernel_size: 3 + text_encoder_positional_encoding_layer_type: "rel_pos" + text_encoder_self_attention_layer_type: "rel_selfattn" + text_encoder_activation_type: "swish" + text_encoder_normalize_before: True + text_encoder_dropout_rate: 0.1 + text_encoder_positional_dropout_rate: 0.0 + text_encoder_attention_dropout_rate: 0.1 + use_macaron_style_in_text_encoder: True + use_conformer_conv_in_text_encoder: False + text_encoder_conformer_kernel_size: -1 + decoder_kernel_size: 7 + decoder_channels: 512 + decoder_upsample_scales: [8, 8, 2, 2] + decoder_upsample_kernel_sizes: [16, 16, 4, 4] + decoder_resblock_kernel_sizes: [3, 7, 11] + decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + use_weight_norm_in_decoder: True + posterior_encoder_kernel_size: 5 + posterior_encoder_layers: 16 + posterior_encoder_stacks: 1 + posterior_encoder_base_dilation: 1 + posterior_encoder_dropout_rate: 0.0 + use_weight_norm_in_posterior_encoder: True + flow_flows: 4 + flow_kernel_size: 5 + flow_base_dilation: 1 + flow_layers: 4 + flow_dropout_rate: 0.0 + use_weight_norm_in_flow: True + use_only_mean_in_flow: True + stochastic_duration_predictor_kernel_size: 3 + stochastic_duration_predictor_dropout_rate: 0.5 + stochastic_duration_predictor_flows: 4 + stochastic_duration_predictor_dds_conv_layers: 3 + # discriminator related + discriminator_type: hifigan_multi_scale_multi_period_discriminator + discriminator_params: + scales: 1 + scale_downsample_pooling: "AvgPool1D" + scale_downsample_pooling_params: + 
kernel_size: 4 + stride: 2 + padding: 2 + scale_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [15, 41, 5, 3] + channels: 128 + max_downsample_channels: 1024 + max_groups: 16 + bias: True + downsample_scales: [2, 2, 4, 4, 1] + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + follow_official_norm: False + periods: [2, 3, 5, 7, 11] + period_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [5, 3] + channels: 32 + downsample_scales: [3, 3, 3, 3, 1] + max_downsample_channels: 1024 + bias: True + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + # others + sampling_rate: 22050 # needed in the inference for saving wav + cache_generator_outputs: True # whether to cache generator outputs in the training + +########################################################### +# LOSS SETTING # +########################################################### +# loss function related +generator_adv_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +discriminator_adv_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +feat_match_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + average_by_layers: False # whether to average loss value by #layers of each discriminator + include_final_outputs: True # whether to include final outputs for loss calculation +mel_loss_params: + fs: 22050 # must be the same as the training data + fft_size: 1024 # fft points + hop_size: 256 # hop size + win_length: null # window length + window: hann # window type + num_mels: 80 # number of Mel basis + fmin: 0 # minimum frequency for Mel basis + fmax: 
null # maximum frequency for Mel basis + log_base: null # null represent natural log + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 1.0 # loss scaling coefficient for adversarial loss +lambda_mel: 45.0 # loss scaling coefficient for Mel loss +lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss +lambda_dur: 1.0 # loss scaling coefficient for duration loss +lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss +# others +sampling_rate: 22050 # needed in the inference for saving wav +cache_generator_outputs: True # whether to cache generator outputs in the training + + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 64 # Batch size. +num_workers: 4 # Number of workers in DataLoader. + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # +########################################################## +# optimizer setting for generator +generator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +generator_scheduler: exponential_decay +generator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 + +# optimizer setting for discriminator +discriminator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +discriminator_scheduler: exponential_decay +discriminator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 +generator_first: False # whether to start updating generator first + +########################################################## +# OTHER TRAINING SETTING # +########################################################## +max_epoch: 1000 # number of epochs +num_snapshots: 10 # max number of snapshots to keep while training +seed: 777 # random seed number diff --git 
a/examples/csmsc/vits/local/preprocess.sh b/examples/csmsc/vits/local/preprocess.sh new file mode 100755 index 000000000..1d3ae5937 --- /dev/null +++ b/examples/csmsc/vits/local/preprocess.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=baker \ + --rootdir=~/datasets/BZNSYP/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." 
+ python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy +fi diff --git a/examples/csmsc/vits/local/synthesize.sh b/examples/csmsc/vits/local/synthesize.sh new file mode 100755 index 000000000..c15d5f99f --- /dev/null +++ b/examples/csmsc/vits/local/synthesize.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test +fi \ No newline at end of file diff --git a/examples/csmsc/vits/local/synthesize_e2e.sh b/examples/csmsc/vits/local/synthesize_e2e.sh new file mode 100755 index 000000000..edbb07bfc --- /dev/null +++ b/examples/csmsc/vits/local/synthesize_e2e.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 
${BIN_DIR}/synthesize_e2e.py \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --output_dir=${train_output_path}/test_e2e \ + --text=${BIN_DIR}/../sentences.txt +fi diff --git a/examples/csmsc/vits/local/train.sh b/examples/csmsc/vits/local/train.sh new file mode 100755 index 000000000..42fff26ca --- /dev/null +++ b/examples/csmsc/vits/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=4 \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/vits/path.sh b/examples/csmsc/vits/path.sh new file mode 100755 index 000000000..52d0c3783 --- /dev/null +++ b/examples/csmsc/vits/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=vits +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} \ No newline at end of file diff --git a/examples/csmsc/vits/run.sh b/examples/csmsc/vits/run.sh new file mode 100755 index 000000000..80e56e7c1 --- /dev/null +++ b/examples/csmsc/vits/run.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_153.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... 
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize_e2e, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/csmsc/voc1/README.md b/examples/csmsc/voc1/README.md index 77da5b185..4646a0345 100644 --- a/examples/csmsc/voc1/README.md +++ b/examples/csmsc/voc1/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/abs/1910.11480) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). ## Dataset ### Download and Extract -Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. @@ -65,7 +65,7 @@ Train a ParallelWaveGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG ParallelWaveGAN config file. --train-metadata TRAIN_METADATA training data. 
--dev-metadata DEV_METADATA diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md index 12adaf7f4..09fb8836c 100644 --- a/examples/csmsc/voc3/README.md +++ b/examples/csmsc/voc3/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). ## Dataset ### Download and Extract -Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. @@ -63,7 +63,7 @@ Train a Multi-Band MelGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG Multi-Band MelGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/csmsc/voc4/README.md b/examples/csmsc/voc4/README.md index b7add3e57..f1a132a84 100644 --- a/examples/csmsc/voc4/README.md +++ b/examples/csmsc/voc4/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [Style MelGAN](https://arxiv.org/abs/2011.01557) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). ## Dataset ### Download and Extract -Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. 
Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. @@ -63,7 +63,7 @@ Train a Style MelGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG Style MelGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/csmsc/voc5/README.md b/examples/csmsc/voc5/README.md index 33e676165..ef552fd30 100644 --- a/examples/csmsc/voc5/README.md +++ b/examples/csmsc/voc5/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). ## Dataset ### Download and Extract -Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. @@ -63,7 +63,7 @@ Train a HiFiGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG HiFiGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA @@ -130,7 +130,7 @@ HiFiGAN checkpoint contains files listed below. 
```text hifigan_csmsc_ckpt_0.1.1 ├── default.yaml # default config used to train hifigan -├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan +├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan └── snapshot_iter_2500000.pdz # generator parameters of hifigan ``` diff --git a/examples/csmsc/voc6/README.md b/examples/csmsc/voc6/README.md index 7dcf133bd..b48c36414 100644 --- a/examples/csmsc/voc6/README.md +++ b/examples/csmsc/voc6/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [WaveRNN](https://arxiv.org/abs/1802.08435) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). ## Dataset ### Download and Extract -Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. @@ -63,7 +63,7 @@ Train a WaveRNN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG WaveRNN config file. --train-metadata TRAIN_METADATA training data. 
--dev-metadata DEV_METADATA diff --git a/examples/esc50/cls0/conf/panns.yaml b/examples/esc50/cls0/conf/panns.yaml index 3a9d42aa5..1f0323f0d 100644 --- a/examples/esc50/cls0/conf/panns.yaml +++ b/examples/esc50/cls0/conf/panns.yaml @@ -1,5 +1,5 @@ data: - dataset: 'paddleaudio.datasets:ESC50' + dataset: 'paddlespeech.audio.datasets:ESC50' num_classes: 50 train: mode: 'train' diff --git a/examples/hey_snips/kws0/conf/mdtc.yaml b/examples/hey_snips/kws0/conf/mdtc.yaml index 4bd0708ce..76e47bc7c 100644 --- a/examples/hey_snips/kws0/conf/mdtc.yaml +++ b/examples/hey_snips/kws0/conf/mdtc.yaml @@ -2,7 +2,7 @@ ########################################### # Data # ########################################### -dataset: 'paddleaudio.datasets:HeySnips' +dataset: 'paddlespeech.audio.datasets:HeySnips' data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' ############################################ diff --git a/examples/librispeech/asr0/RESULTS.md b/examples/librispeech/asr0/RESULTS.md index 9f6d1cc04..5e5ce387b 100644 --- a/examples/librispeech/asr0/RESULTS.md +++ b/examples/librispeech/asr0/RESULTS.md @@ -3,6 +3,7 @@ ## Deepspeech2 Non-Streaming | Model | Params | release | Config | Test set | Loss | WER | | --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 113.96M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipline and spec aug + fbank161 | test-clean | 10.76069622039795 | 0.046700 | | DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 | | DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 | | DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 | diff --git a/examples/librispeech/asr0/conf/augmentation.json b/examples/librispeech/asr0/conf/augmentation.json deleted file mode 100644 index 31c481c8d..000000000 --- a/examples/librispeech/asr0/conf/augmentation.json +++ /dev/null @@ -1,36 +0,0 @@ -[ - { - "type": "speed", - 
"params": { - "min_speed_rate": 0.9, - "max_speed_rate": 1.1, - "num_rates": 3 - }, - "prob": 0.0 - }, - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - }, - { - "type": "specaug", - "params": { - "W": 0, - "warp_mode": "PIL", - "F": 10, - "n_freq_masks": 2, - "T": 50, - "n_time_masks": 2, - "p": 1.0, - "adaptive_number_ratio": 0, - "adaptive_size_ratio": 0, - "max_n_time_masks": 20, - "replace_with_zero": true - }, - "prob": 1.0 - } -] diff --git a/examples/librispeech/asr0/conf/deepspeech2.yaml b/examples/librispeech/asr0/conf/deepspeech2.yaml index 0307b9f39..cca695fe5 100644 --- a/examples/librispeech/asr0/conf/deepspeech2.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2.yaml @@ -15,51 +15,51 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 20 -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -target_sample_rate: 16000 -max_freq: None -n_fft: None +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 161 stride_ms: 10.0 -window_ms: 20.0 -delta_delta: False -dither: 1.0 -use_dB_normalization: True -target_dB: -20 -random_seed: 0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 
+num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 2048 +num_rnn_layers: 5 +rnn_layer_size: 1024 +rnn_direction: bidirect +num_fc_layers: 0 +fc_layers_size_list: -1 use_gru: False -share_rnn_weights: True blank_id: 0 ########################################### # Training # ########################################### -n_epoch: 50 +n_epoch: 15 accum_grad: 1 -lr: 1.0e-3 -lr_decay: 0.83 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 5.0 -log_interval: 100 +dist_sampler: False +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/librispeech/asr0/conf/deepspeech2_online.yaml b/examples/librispeech/asr0/conf/deepspeech2_online.yaml index a0d2bcfe2..93421ef44 100644 --- a/examples/librispeech/asr0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2_online.yaml @@ -15,39 +15,36 @@ max_output_input_ratio: .inf ########################################### # Dataloader # ########################################### -batch_size: 15 -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -target_sample_rate: 16000 -max_freq: None -n_fft: None +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml +feat_dim: 161 stride_ms: 10.0 -window_ms: 20.0 -delta_delta: False -dither: 1.0 -use_dB_normalization: True -target_dB: -20 -random_seed: 0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > 
maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # ############################################ num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 2048 +num_rnn_layers: 5 +rnn_layer_size: 1024 rnn_direction: forward -num_fc_layers: 2 -fc_layers_size_list: 512, 256 +num_fc_layers: 0 +fc_layers_size_list: -1 use_gru: False blank_id: 0 @@ -55,13 +52,13 @@ blank_id: 0 ########################################### # Training # ########################################### -n_epoch: 50 -accum_grad: 4 -lr: 1.0e-3 -lr_decay: 0.83 +n_epoch: 65 +accum_grad: 1 +lr: 5.0e-4 +lr_decay: 0.93 weight_decay: 1.0e-6 global_grad_clip: 5.0 -log_interval: 100 +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/librispeech/asr0/conf/preprocess.yaml b/examples/librispeech/asr0/conf/preprocess.yaml new file mode 100644 index 000000000..3f526e0ad --- /dev/null +++ b/examples/librispeech/asr0/conf/preprocess.yaml @@ -0,0 +1,25 @@ +process: + # extract kaldi fbank from PCM + - type: fbank_kaldi + fs: 16000 + n_mels: 161 + n_shift: 160 + win_length: 400 + dither: 0.1 + - type: cmvn_json + cmvn_path: data/mean_std.json + # these three processes are a.k.a. 
SpecAugument + - type: time_warp + max_time_warp: 5 + inplace: true + mode: PIL + - type: freq_mask + F: 30 + n_mask: 2 + inplace: true + replace_with_zero: false + - type: time_mask + T: 40 + n_mask: 2 + inplace: true + replace_with_zero: false diff --git a/examples/librispeech/asr0/local/data.sh b/examples/librispeech/asr0/local/data.sh index b97e8c211..a28fddc96 100755 --- a/examples/librispeech/asr0/local/data.sh +++ b/examples/librispeech/asr0/local/data.sh @@ -49,12 +49,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ --num_samples=2000 \ - --spectrum_type="linear" \ + --spectrum_type="fbank" \ + --feat_dim=161 \ --delta_delta=false \ --sample_rate=16000 \ --stride_ms=10 \ - --window_ms=20 \ - --use_dB_normalization=True \ + --window_ms=25 \ + --use_dB_normalization=False \ --num_workers=${num_workers} \ --output_path="data/mean_std.json" diff --git a/examples/librispeech/asr0/local/export.sh b/examples/librispeech/asr0/local/export.sh index 426a72fe5..ce7e6d642 100755 --- a/examples/librispeech/asr0/local/export.sh +++ b/examples/librispeech/asr0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" exit -1 fi @@ -11,14 +11,12 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -model_type=$4 python3 -u ${BIN_DIR}/export.py \ --ngpu ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} \ ---model_type ${model_type} +--export_path ${jit_model_export_path} if [ $? -ne 0 ]; then echo "Failed in export!" 
diff --git a/examples/librispeech/asr0/local/test.sh b/examples/librispeech/asr0/local/test.sh index ea40046b1..728569d1f 100755 --- a/examples/librispeech/asr0/local/test.sh +++ b/examples/librispeech/asr0/local/test.sh @@ -1,9 +1,11 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi +stage=0 +stop_stage=100 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." @@ -11,7 +13,6 @@ echo "using $ngpu gpus..." config_path=$1 decode_config_path=$2 ckpt_prefix=$3 -model_type=$4 # download language model bash local/download_lm_en.sh @@ -19,17 +20,43 @@ if [ $? -ne 0 ]; then exit 1 fi -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # format the reference test file + python3 utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref data/manifest.test-clean.text + + python3 -u ${BIN_DIR}/test.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${ckpt_prefix}.rsl \ + --checkpoint_path ${ckpt_prefix} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi + + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp ${ckpt_prefix}.rsl.text + + python3 utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test-clean.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error +fi -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" 
- exit 1 +if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + python3 utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref_sclite data/manifest.test.text-clean.sclite + + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite + + mkdir -p ${ckpt_prefix}_sclite + sclite -i wsj -r data/manifest.test-clean.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII fi diff --git a/examples/librispeech/asr0/local/test_wav.sh b/examples/librispeech/asr0/local/test_wav.sh index 25cfc45e3..a5712b608 100755 --- a/examples/librispeech/asr0/local/test_wav.sh +++ b/examples/librispeech/asr0/local/test_wav.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 5 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type audio_file" +if [ $# != 4 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file" exit -1 fi @@ -11,8 +11,7 @@ echo "using $ngpu gpus..." config_path=$1 decode_config_path=$2 ckpt_prefix=$3 -model_type=$4 -audio_file=$5 +audio_file=$4 mkdir -p data wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/ @@ -37,7 +36,6 @@ python3 -u ${BIN_DIR}/test_wav.py \ --decode_cfg ${decode_config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} \ --audio_file ${audio_file} if [ $? 
-ne 0 ]; then diff --git a/examples/librispeech/asr0/local/train.sh b/examples/librispeech/asr0/local/train.sh index 50d1d1922..71659e28d 100755 --- a/examples/librispeech/asr0/local/train.sh +++ b/examples/librispeech/asr0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,7 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -model_type=$3 +ips=$3 + +if [ ! $ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -25,14 +31,12 @@ python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --seed ${seed} fi diff --git a/examples/librispeech/asr0/run.sh b/examples/librispeech/asr0/run.sh index ca2c2b9da..38112398a 100755 --- a/examples/librispeech/asr0/run.sh +++ b/examples/librispeech/asr0/run.sh @@ -2,13 +2,13 @@ set -e source path.sh -gpus=0,1,2,3,4,5,6,7 +gpus=0,1,2,3 stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml -avg_num=30 -model_type=offline +avg_num=5 audio_file=data/demo_002_en.wav source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -24,7 +24,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ 
${stop_stage} -ge 2 ]; then @@ -34,15 +34,20 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}|| exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # test export ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit|| exit -1 +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then # test a single .wav file - CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} ${audio_file} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 fi diff --git a/examples/librispeech/asr1/local/test.sh b/examples/librispeech/asr1/local/test.sh index 51ced18b2..03cef9a62 100755 --- a/examples/librispeech/asr1/local/test.sh +++ b/examples/librispeech/asr1/local/test.sh @@ -42,6 +42,11 @@ echo "chunk mode ${chunk_mode}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # format the reference test file + python3 utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref data/manifest.test-clean.text + for type in attention; do echo "decoding ${type}" if [ ${chunk_mode} == true ];then @@ -63,54 +68,90 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then echo "Failed in 
evaluation!" exit 1 fi + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.${type}.rsl \ + --trans_hyp ${ckpt_prefix}.${type}.rsl.text + + python3 utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error + echo "decoding ${type} done." + done + + for type in ctc_greedy_search; do + echo "decoding ${type}" + if [ ${chunk_mode} == true ];then + # stream decoding only support batchsize=1 + batch_size=1 + else + batch_size=64 + fi + python3 -u ${BIN_DIR}/test.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decode.decoding_method ${type} \ + --opts decode.decode_batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.${type}.rsl \ + --trans_hyp ${ckpt_prefix}.${type}.rsl.text + + python3 utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error echo "decoding ${type} done." done -fi -for type in ctc_greedy_search; do - echo "decoding ${type}" - if [ ${chunk_mode} == true ];then - # stream decoding only support batchsize=1 + + + for type in ctc_prefix_beam_search attention_rescoring; do + echo "decoding ${type}" batch_size=1 - else - batch_size=64 - fi - python3 -u ${BIN_DIR}/test.py \ - --ngpu ${ngpu} \ - --config ${config_path} \ - --decode_cfg ${decode_config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ - --checkpoint_path ${ckpt_prefix} \ - --opts decode.decoding_method ${type} \ - --opts decode.decode_batch_size ${batch_size} - - if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 - fi - echo "decoding ${type} done." 
-done - - - -for type in ctc_prefix_beam_search attention_rescoring; do - echo "decoding ${type}" - batch_size=1 - python3 -u ${BIN_DIR}/test.py \ - --ngpu ${ngpu} \ - --config ${config_path} \ - --decode_cfg ${decode_config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ - --checkpoint_path ${ckpt_prefix} \ - --opts decode.decoding_method ${type} \ - --opts decode.decode_batch_size ${batch_size} - - if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 - fi - echo "decoding ${type} done." -done + python3 -u ${BIN_DIR}/test.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decode.decoding_method ${type} \ + --opts decode.decode_batch_size ${batch_size} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi + python3 utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.${type}.rsl \ + --trans_hyp ${ckpt_prefix}.${type}.rsl.text + + python3 utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error + echo "decoding ${type} done." 
+ done +fi + +if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + python3 utils/format_rsl.py \ + --origin_ref data/manifest.test-clean.raw \ + --trans_ref_sclite data/manifest.test.text-clean.sclite + + + output_dir=${ckpt_prefix} + for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do + python utils/format_rsl.py \ + --origin_hyp ${output_dir}/${type}.rsl \ + --trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite + + mkdir -p ${output_dir}/${type}_sclite + sclite -i wsj -r data/manifest.test-clean.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII + done +fi + echo "Finished" diff --git a/examples/librispeech/asr1/local/train.sh b/examples/librispeech/asr1/local/train.sh index 3860d85cf..f729ed22c 100755 --- a/examples/librispeech/asr1/local/train.sh +++ b/examples/librispeech/asr1/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,6 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +ips=$3 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -29,7 +36,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ diff --git a/examples/librispeech/asr1/run.sh b/examples/librispeech/asr1/run.sh index 116dae126..a14240ee0 100755 --- a/examples/librispeech/asr1/run.sh +++ b/examples/librispeech/asr1/run.sh @@ -8,6 +8,7 @@ gpus=0,1,2,3 stage=0 stop_stage=50 conf_path=conf/transformer.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=30 audio_file=data/demo_002_en.wav @@ -25,7 +26,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/librispeech/asr2/local/train.sh b/examples/librispeech/asr2/local/train.sh index 560424ea4..1f414ad41 100755 --- a/examples/librispeech/asr2/local/train.sh +++ b/examples/librispeech/asr2/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,6 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +ips=$3 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -27,7 +34,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --model-name u2_kaldi \ --config ${config_path} \ diff --git a/examples/librispeech/asr2/run.sh b/examples/librispeech/asr2/run.sh index c9a794e34..d156159f2 100755 --- a/examples/librispeech/asr2/run.sh +++ b/examples/librispeech/asr2/run.sh @@ -9,6 +9,7 @@ gpus=0,1,2,3,4,5,6,7 stage=0 stop_stage=50 conf_path=conf/transformer.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/decode/decode_base.yaml dict_path=data/lang_char/train_960_unigram5000_units.txt avg_num=10 @@ -26,7 +27,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/ljspeech/tts0/README.md b/examples/ljspeech/tts0/README.md index ba7ad6193..85d9e448b 100644 --- a/examples/ljspeech/tts0/README.md +++ b/examples/ljspeech/tts0/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ## Dataset ### Download and Extract -Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/). +Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2, the durations of MFA are not needed here. 
@@ -103,12 +103,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -117,11 +117,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} + --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -133,10 +132,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. 
- --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -152,12 +151,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -167,11 +166,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose 
acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -182,10 +180,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. --spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -198,9 +196,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. 
`--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. diff --git a/examples/ljspeech/tts1/README.md b/examples/ljspeech/tts1/README.md index 7f32522ac..85621653f 100644 --- a/examples/ljspeech/tts1/README.md +++ b/examples/ljspeech/tts1/README.md @@ -1,13 +1,10 @@ # TransformerTTS with LJSpeech ## Dataset -We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/). - -```bash -wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 -tar xjvf LJSpeech-1.1.tar.bz2 -``` +### Download and Extract +Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ## Get Started -Assume the path to the dataset is `~/datasets/LJSpeech-1.1`. +Assume the path to the dataset is `~/datasets/LJSpeech-1.1` and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. + Run the command below to 1. **source path**. 2. preprocess the dataset. @@ -61,7 +58,7 @@ Train a TransformerTTS model with LJSpeech TTS dataset. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG TransformerTTS config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/ljspeech/tts3/README.md b/examples/ljspeech/tts3/README.md index e028fa05d..81a0580c0 100644 --- a/examples/ljspeech/tts3/README.md +++ b/examples/ljspeech/tts3/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ## Dataset ### Download and Extract -Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/). 
+Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2. @@ -107,14 +107,14 @@ pwg_ljspeech_ckpt_0.5 ```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` -``text +```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -123,11 +123,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. 
--am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -139,10 +138,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -158,12 +157,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -173,11 +172,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am 
{speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -188,10 +186,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. --spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -204,9 +202,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. 
`--lang` is the model language, which can be `zh` or `en`. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. diff --git a/examples/ljspeech/voc0/README.md b/examples/ljspeech/voc0/README.md index 41b08d57f..ae48a9a7f 100644 --- a/examples/ljspeech/voc0/README.md +++ b/examples/ljspeech/voc0/README.md @@ -1,11 +1,7 @@ # WaveFlow with LJSpeech ## Dataset -We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/). - -```bash -wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 -tar xjvf LJSpeech-1.1.tar.bz2 -``` +### Download and Extract +Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ## Get Started Assume the path to the dataset is `~/datasets/LJSpeech-1.1`. Assume the path to the Tacotron2 generated mels is `../tts0/output/test`. diff --git a/examples/ljspeech/voc1/README.md b/examples/ljspeech/voc1/README.md index 4513b2a05..d16c0e35f 100644 --- a/examples/ljspeech/voc1/README.md +++ b/examples/ljspeech/voc1/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/abs/1910.11480) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/). ## Dataset ### Download and Extract -Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/). +Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. 
You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo. @@ -65,7 +65,7 @@ Train a ParallelWaveGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG ParallelWaveGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md index 9b31e2650..d856cfecf 100644 --- a/examples/ljspeech/voc5/README.md +++ b/examples/ljspeech/voc5/README.md @@ -2,7 +2,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/). ## Dataset ### Download and Extract -Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/). +Download LJSpeech-1.1 from it's [Official Website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo. @@ -57,15 +57,13 @@ Here's the complete help message. 
```text usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] - [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] - [--run-benchmark RUN_BENCHMARK] - [--profiler_options PROFILER_OPTIONS] + [--ngpu NGPU] -Train a ParallelWaveGAN model. +Train a HiFiGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG HiFiGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA @@ -73,19 +71,6 @@ optional arguments: --output-dir OUTPUT_DIR output dir. --ngpu NGPU if ngpu == 0, use cpu. - -benchmark: - arguments related to benchmark. - - --batch-size BATCH_SIZE - batch size. - --max-iter MAX_ITER train max steps. - --run-benchmark RUN_BENCHMARK - runing benchmark or not, if True, use the --batch-size - and --max-iter. - --profiler_options PROFILER_OPTIONS - The option of profiler, which should be in format - "key1=value1;key2=value2;key3=value3". ``` 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. diff --git a/examples/mustc/st1/local/train.sh b/examples/mustc/st1/local/train.sh index 456c94169..db2a575a6 100755 --- a/examples/mustc/st1/local/train.sh +++ b/examples/mustc/st1/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path" +if [ $# -lt 3 ] && [ $# -gt 4 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path ips(optional)" exit -1 fi @@ -11,6 +11,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 ckpt_path=$3 +ips=$3 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -21,12 +28,21 @@ if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi +if [ ${ngpu} == 0 ]; then python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --checkpoint_path "${ckpt_path}" \ --seed ${seed} +else +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ +--ngpu ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} \ +--checkpoint_path "${ckpt_path}" \ +--seed ${seed} +fi if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic diff --git a/examples/mustc/st1/run.sh b/examples/mustc/st1/run.sh index 6ceae3b84..99ee2295c 100755 --- a/examples/mustc/st1/run.sh +++ b/examples/mustc/st1/run.sh @@ -7,6 +7,7 @@ gpus=0,1,2,3 stage=0 stop_stage=3 conf_path=conf/transformer_es.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml must_c_path= lang=es @@ -25,7 +26,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -36,4 +37,4 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${lang} || exit -1 -fi \ No newline at end of file +fi diff --git a/examples/other/1xt2x/.gitignore b/examples/other/1xt2x/.gitignore deleted file mode 100644 index a9a5aecf4..000000000 --- a/examples/other/1xt2x/.gitignore +++ /dev/null @@ -1 +0,0 @@ -tmp diff --git a/examples/other/1xt2x/README.md b/examples/other/1xt2x/README.md deleted file mode 100644 index 49f850d26..000000000 --- a/examples/other/1xt2x/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# 
1xt2x - -Convert Deepspeech 1.8 released model to 2.x. - -## Model source directory -* Deepspeech2x - -## Expriment directory -* aishell -* librispeech -* baidu_en8k - -# The released model - -Acoustic Model | Training Data | Hours of Speech | Token-based | CER | WER -:-------------:| :------------:| :---------------: | :---------: | :---: | :----: -Ds2 Offline Aishell 1xt2x model| Aishell Dataset | 151 h | Char-based | 0.080447 | -Ds2 Offline Librispeech 1xt2x model | Librispeech Dataset | 960 h | Word-based | | 0.068548 -Ds2 Offline Baidu en8k 1x2x model | Baidu Internal English Dataset | 8628 h |Word-based | | 0.054112 diff --git a/examples/other/1xt2x/aishell/.gitignore b/examples/other/1xt2x/aishell/.gitignore deleted file mode 100644 index 3631e544a..000000000 --- a/examples/other/1xt2x/aishell/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -exp -data -*log -tmp -nohup* diff --git a/examples/other/1xt2x/aishell/conf/augmentation.json b/examples/other/1xt2x/aishell/conf/augmentation.json deleted file mode 100644 index fe51488c7..000000000 --- a/examples/other/1xt2x/aishell/conf/augmentation.json +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/examples/other/1xt2x/aishell/conf/deepspeech2.yaml b/examples/other/1xt2x/aishell/conf/deepspeech2.yaml deleted file mode 100644 index c2db2c7c2..000000000 --- a/examples/other/1xt2x/aishell/conf/deepspeech2.yaml +++ /dev/null @@ -1,65 +0,0 @@ -# https://yaml.org/type/float.html -########################################### -# Data # -########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test -min_input_len: 0.0 -max_input_len: 27.0 # second -min_output_len: 0.0 -max_output_len: .inf -min_output_input_ratio: 0.00 -max_output_input_ratio: .inf - -########################################### -# Dataloader # -########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.npz -unit_type: char 
-vocab_filepath: data/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -delta_delta: False -stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 - -############################################ -# Network Architecture # -############################################ -num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 1024 -use_gru: True -share_rnn_weights: False -blank_id: 4333 - -########################################### -# Training # -########################################### -n_epoch: 80 -accum_grad: 1 -lr: 2e-3 -lr_decay: 0.83 -weight_decay: 1e-06 -global_grad_clip: 3.0 -log_interval: 100 -checkpoint: - kbest_n: 50 - latest_n: 5 - - diff --git a/examples/other/1xt2x/aishell/conf/tuning/decode.yaml b/examples/other/1xt2x/aishell/conf/tuning/decode.yaml deleted file mode 100644 index b5283a934..000000000 --- a/examples/other/1xt2x/aishell/conf/tuning/decode.yaml +++ /dev/null @@ -1,10 +0,0 @@ -decode_batch_size: 32 -error_rate_type: cer -decoding_method: ctc_beam_search -lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm -alpha: 2.6 -beta: 5.0 -beam_size: 300 -cutoff_prob: 0.99 -cutoff_top_n: 40 -num_proc_bsearch: 8 \ No newline at end of file diff --git a/examples/other/1xt2x/aishell/local/data.sh b/examples/other/1xt2x/aishell/local/data.sh deleted file mode 100755 index a9d5b1412..000000000 --- a/examples/other/1xt2x/aishell/local/data.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/bin/bash -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - -stage=-1 -stop_stage=100 - -source ${MAIN_ROOT}/utils/parse_options.sh - -mkdir -p data -TARGET_DIR=${MAIN_ROOT}/dataset -mkdir -p ${TARGET_DIR} - -bash local/download_model.sh ${ckpt_dir} -if [ $? 
-ne 0 ]; then - exit 1 -fi - -cd ${ckpt_dir} -tar xzvf aishell_model_v1.8_to_v2.x.tar.gz -cd - -mv ${ckpt_dir}/mean_std.npz data/ -mv ${ckpt_dir}/vocab.txt data/ - - -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # download data, generate manifests - python3 ${TARGET_DIR}/aishell/aishell.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/aishell" - - if [ $? -ne 0 ]; then - echo "Prepare Aishell failed. Terminated." - exit 1 - fi - - for dataset in train dev test; do - mv data/manifest.${dataset} data/manifest.${dataset}.raw - done -fi - - - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for dataset in train dev test; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.npz" \ - --unit_type "char" \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${dataset}.raw" \ - --output_path="data/manifest.${dataset}" - - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 - fi - } & - done - wait -fi - -echo "Aishell data preparation done." -exit 0 diff --git a/examples/other/1xt2x/aishell/local/download_lm_ch.sh b/examples/other/1xt2x/aishell/local/download_lm_ch.sh deleted file mode 100755 index 47153f4b6..000000000 --- a/examples/other/1xt2x/aishell/local/download_lm_ch.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -. ${MAIN_ROOT}/utils/utility.sh - -DIR=data/lm -mkdir -p ${DIR} - -URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm' -MD5="29e02312deb2e59b3c8686c7966d4fe3" -TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm - - -echo "Start downloading the language model. The language model is large, please wait for a moment ..." -download $URL $MD5 $TARGET > /dev/null 2>&1 -if [ $? -ne 0 ]; then - echo "Fail to download the language model!" 
- exit 1 -else - echo "Download the language model sucessfully" -fi - - -exit 0 diff --git a/examples/other/1xt2x/aishell/local/download_model.sh b/examples/other/1xt2x/aishell/local/download_model.sh deleted file mode 100644 index ffa2f8101..000000000 --- a/examples/other/1xt2x/aishell/local/download_model.sh +++ /dev/null @@ -1,25 +0,0 @@ -#! /usr/bin/env bash - -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - -. ${MAIN_ROOT}/utils/utility.sh - -URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz' -MD5=87e7577d4bea737dbf3e8daab37aa808 -TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz - - -echo "Download Aishell model ..." -download $URL $MD5 $TARGET -if [ $? -ne 0 ]; then - echo "Fail to download Aishell model!" - exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/aishell/local/test.sh b/examples/other/1xt2x/aishell/local/test.sh deleted file mode 100755 index 463593ef3..000000000 --- a/examples/other/1xt2x/aishell/local/test.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" - exit -1 -fi - -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." - -config_path=$1 -decode_config_path=$2 -ckpt_prefix=$3 -model_type=$4 - -# download language model -bash local/download_lm_ch.sh -if [ $? -ne 0 ]; then - exit 1 -fi - -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} - -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" 
- exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/aishell/path.sh b/examples/other/1xt2x/aishell/path.sh deleted file mode 100644 index ce44e65cb..000000000 --- a/examples/other/1xt2x/aishell/path.sh +++ /dev/null @@ -1,17 +0,0 @@ -export MAIN_ROOT=`realpath ${PWD}/../../../../` -export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` - -export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} -export LC_ALL=C - -export PYTHONDONTWRITEBYTECODE=1 -# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} -export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} - -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ - -MODEL=deepspeech2 -export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin -echo "BIN_DIR "${BIN_DIR} diff --git a/examples/other/1xt2x/aishell/run.sh b/examples/other/1xt2x/aishell/run.sh deleted file mode 100755 index 89a634119..000000000 --- a/examples/other/1xt2x/aishell/run.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -e -source path.sh - -stage=0 -stop_stage=100 -conf_path=conf/deepspeech2.yaml -decode_conf_path=conf/tuning/decode.yaml -avg_num=1 -model_type=offline -gpus=2 - -source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; - -v18_ckpt=aishell_v1.8 -ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') -echo "checkpoint name ${ckpt}" - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # prepare data - mkdir -p exp/${ckpt}/checkpoints - bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1 -fi - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 -fi - diff --git a/examples/other/1xt2x/baidu_en8k/.gitignore b/examples/other/1xt2x/baidu_en8k/.gitignore deleted file mode 100644 index 3631e544a..000000000 --- a/examples/other/1xt2x/baidu_en8k/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -exp -data -*log -tmp -nohup* diff --git a/examples/other/1xt2x/baidu_en8k/conf/augmentation.json b/examples/other/1xt2x/baidu_en8k/conf/augmentation.json deleted file mode 100644 index fe51488c7..000000000 --- a/examples/other/1xt2x/baidu_en8k/conf/augmentation.json +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/examples/other/1xt2x/baidu_en8k/conf/deepspeech2.yaml b/examples/other/1xt2x/baidu_en8k/conf/deepspeech2.yaml deleted file mode 100644 index 0c08fbc63..000000000 --- a/examples/other/1xt2x/baidu_en8k/conf/deepspeech2.yaml +++ /dev/null @@ -1,64 +0,0 @@ -# https://yaml.org/type/float.html -########################################### -# Data # -########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test-clean -min_input_len: 0.0 -max_input_len: .inf # second -min_output_len: 0.0 -max_output_len: .inf -min_output_input_ratio: 0.00 -max_output_input_ratio: .inf - -########################################### -# Dataloader # -########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.npz -unit_type: char -vocab_filepath: data/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -delta_delta: False 
-stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 - -############################################ -# Network Architecture # -############################################ -num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 1024 -use_gru: True -share_rnn_weights: False -blank_id: 28 - -########################################### -# Training # -########################################### -n_epoch: 80 -accum_grad: 1 -lr: 2e-3 -lr_decay: 0.83 -weight_decay: 1e-06 -global_grad_clip: 3.0 -log_interval: 100 -checkpoint: - kbest_n: 50 - latest_n: 5 - diff --git a/examples/other/1xt2x/baidu_en8k/conf/tuning/decode.yaml b/examples/other/1xt2x/baidu_en8k/conf/tuning/decode.yaml deleted file mode 100644 index f52dde320..000000000 --- a/examples/other/1xt2x/baidu_en8k/conf/tuning/decode.yaml +++ /dev/null @@ -1,10 +0,0 @@ -decode_batch_size: 32 -error_rate_type: wer -decoding_method: ctc_beam_search -lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm -alpha: 1.4 -beta: 0.35 -beam_size: 500 -cutoff_prob: 1.0 -cutoff_top_n: 40 -num_proc_bsearch: 8 \ No newline at end of file diff --git a/examples/other/1xt2x/baidu_en8k/local/data.sh b/examples/other/1xt2x/baidu_en8k/local/data.sh deleted file mode 100755 index 9b017324d..000000000 --- a/examples/other/1xt2x/baidu_en8k/local/data.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/bin/bash -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - -stage=-1 -stop_stage=100 -unit_type=char - -source ${MAIN_ROOT}/utils/parse_options.sh - -mkdir -p data -TARGET_DIR=${MAIN_ROOT}/dataset -mkdir -p ${TARGET_DIR} - - -bash local/download_model.sh ${ckpt_dir} -if [ $? 
-ne 0 ]; then - exit 1 -fi - -cd ${ckpt_dir} -tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz -cd - -mv ${ckpt_dir}/mean_std.npz data/ -mv ${ckpt_dir}/vocab.txt data/ - - -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # download data, generate manifests - python3 ${TARGET_DIR}/librispeech/librispeech.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/librispeech" \ - --full_download="True" - - if [ $? -ne 0 ]; then - echo "Prepare LibriSpeech failed. Terminated." - exit 1 - fi - - for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do - mv data/manifest.${set} data/manifest.${set}.raw - done - - rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw - for set in train-clean-100 train-clean-360 train-other-500; do - cat data/manifest.${set}.raw >> data/manifest.train.raw - done - - for set in dev-clean dev-other; do - cat data/manifest.${set}.raw >> data/manifest.dev.raw - done - - for set in test-clean test-other; do - cat data/manifest.${set}.raw >> data/manifest.test.raw - done -fi - - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for set in train dev test dev-clean dev-other test-clean test-other; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.npz" \ - --unit_type ${unit_type} \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${set}.raw" \ - --output_path="data/manifest.${set}" - - if [ $? -ne 0 ]; then - echo "Formt mnaifest.${set} failed. Terminated." - exit 1 - fi - }& - done - wait -fi - -echo "LibriSpeech Data preparation done." -exit 0 - diff --git a/examples/other/1xt2x/baidu_en8k/local/download_lm_en.sh b/examples/other/1xt2x/baidu_en8k/local/download_lm_en.sh deleted file mode 100755 index 390fffc93..000000000 --- a/examples/other/1xt2x/baidu_en8k/local/download_lm_en.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -. 
${MAIN_ROOT}/utils/utility.sh - -DIR=data/lm -mkdir -p ${DIR} - -URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm -MD5="099a601759d467cd0a8523ff939819c5" -TARGET=${DIR}/common_crawl_00.prune01111.trie.klm - -echo "Start downloading the language model. The language model is large, please wait for a moment ..." -download $URL $MD5 $TARGET > /dev/null 2>&1 -if [ $? -ne 0 ]; then - echo "Fail to download the language model!" - exit 1 -else - echo "Download the language model sucessfully" -fi - - -exit 0 diff --git a/examples/other/1xt2x/baidu_en8k/local/download_model.sh b/examples/other/1xt2x/baidu_en8k/local/download_model.sh deleted file mode 100644 index a8fbc31e8..000000000 --- a/examples/other/1xt2x/baidu_en8k/local/download_model.sh +++ /dev/null @@ -1,25 +0,0 @@ -#! /usr/bin/env bash -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - - -. ${MAIN_ROOT}/utils/utility.sh - -URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz' -MD5=c1676be8505cee436e6f312823e9008c -TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz - - -echo "Download BaiduEn8k model ..." -download $URL $MD5 $TARGET -if [ $? -ne 0 ]; then - echo "Fail to download BaiduEn8k model!" - exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/baidu_en8k/local/test.sh b/examples/other/1xt2x/baidu_en8k/local/test.sh deleted file mode 100755 index ea40046b1..000000000 --- a/examples/other/1xt2x/baidu_en8k/local/test.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" - exit -1 -fi - -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." - -config_path=$1 -decode_config_path=$2 -ckpt_prefix=$3 -model_type=$4 - -# download language model -bash local/download_lm_en.sh -if [ $? 
-ne 0 ]; then - exit 1 -fi - -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} - -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/baidu_en8k/path.sh b/examples/other/1xt2x/baidu_en8k/path.sh deleted file mode 100644 index ce44e65cb..000000000 --- a/examples/other/1xt2x/baidu_en8k/path.sh +++ /dev/null @@ -1,17 +0,0 @@ -export MAIN_ROOT=`realpath ${PWD}/../../../../` -export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` - -export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} -export LC_ALL=C - -export PYTHONDONTWRITEBYTECODE=1 -# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} -export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} - -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ - -MODEL=deepspeech2 -export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin -echo "BIN_DIR "${BIN_DIR} diff --git a/examples/other/1xt2x/baidu_en8k/run.sh b/examples/other/1xt2x/baidu_en8k/run.sh deleted file mode 100755 index 82de56b09..000000000 --- a/examples/other/1xt2x/baidu_en8k/run.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -e -source path.sh - -stage=0 -stop_stage=100 -conf_path=conf/deepspeech2.yaml -decode_conf_path=conf/tuning/decode.yaml -avg_num=1 -model_type=offline -gpus=0 - -source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; - -v18_ckpt=baidu_en8k_v1.8 -ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') -echo "checkpoint name ${ckpt}" - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # prepare data - mkdir -p exp/${ckpt}/checkpoints - bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1 -fi - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 -fi - diff --git a/examples/other/1xt2x/librispeech/.gitignore b/examples/other/1xt2x/librispeech/.gitignore deleted file mode 100644 index 3631e544a..000000000 --- a/examples/other/1xt2x/librispeech/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -exp -data -*log -tmp -nohup* diff --git a/examples/other/1xt2x/librispeech/conf/augmentation.json b/examples/other/1xt2x/librispeech/conf/augmentation.json deleted file mode 100644 index fe51488c7..000000000 --- a/examples/other/1xt2x/librispeech/conf/augmentation.json +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/examples/other/1xt2x/librispeech/conf/deepspeech2.yaml b/examples/other/1xt2x/librispeech/conf/deepspeech2.yaml deleted file mode 100644 index a2a5649ba..000000000 --- a/examples/other/1xt2x/librispeech/conf/deepspeech2.yaml +++ /dev/null @@ -1,64 +0,0 @@ -# https://yaml.org/type/float.html -########################################### -# Data # -########################################### -train_manifest: data/manifest.train -dev_manifest: data/manifest.dev -test_manifest: data/manifest.test-clean -min_input_len: 0.0 -max_input_len: 1000.0 # second -min_output_len: 0.0 -max_output_len: .inf -min_output_input_ratio: 0.00 -max_output_input_ratio: .inf - -########################################### -# Dataloader # -########################################### -batch_size: 64 # one gpu -mean_std_filepath: data/mean_std.npz -unit_type: char -vocab_filepath: data/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear -feat_dim: -delta_delta: 
False -stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 - -############################################ -# Network Architecture # -############################################ -num_conv_layers: 2 -num_rnn_layers: 3 -rnn_layer_size: 2048 -use_gru: False -share_rnn_weights: True -blank_id: 28 - -########################################### -# Training # -########################################### -n_epoch: 80 -accum_grad: 1 -lr: 2e-3 -lr_decay: 0.83 -weight_decay: 1e-06 -global_grad_clip: 3.0 -log_interval: 100 -checkpoint: - kbest_n: 50 - latest_n: 5 - diff --git a/examples/other/1xt2x/librispeech/conf/tuning/decode.yaml b/examples/other/1xt2x/librispeech/conf/tuning/decode.yaml deleted file mode 100644 index f3b51defe..000000000 --- a/examples/other/1xt2x/librispeech/conf/tuning/decode.yaml +++ /dev/null @@ -1,10 +0,0 @@ -decode_batch_size: 32 -error_rate_type: wer -decoding_method: ctc_beam_search -lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm -alpha: 2.5 -beta: 0.3 -beam_size: 500 -cutoff_prob: 1.0 -cutoff_top_n: 40 -num_proc_bsearch: 8 \ No newline at end of file diff --git a/examples/other/1xt2x/librispeech/local/data.sh b/examples/other/1xt2x/librispeech/local/data.sh deleted file mode 100755 index 43b5426d9..000000000 --- a/examples/other/1xt2x/librispeech/local/data.sh +++ /dev/null @@ -1,83 +0,0 @@ -#!/bin/bash - -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - -stage=-1 -stop_stage=100 -unit_type=char - -source ${MAIN_ROOT}/utils/parse_options.sh - -mkdir -p data -TARGET_DIR=${MAIN_ROOT}/dataset -mkdir -p ${TARGET_DIR} - -bash local/download_model.sh ${ckpt_dir} -if [ $? 
-ne 0 ]; then - exit 1 -fi - -cd ${ckpt_dir} -tar xzvf librispeech_v1.8_to_v2.x.tar.gz -cd - -mv ${ckpt_dir}/mean_std.npz data/ -mv ${ckpt_dir}/vocab.txt data/ - -if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - # download data, generate manifests - python3 ${TARGET_DIR}/librispeech/librispeech.py \ - --manifest_prefix="data/manifest" \ - --target_dir="${TARGET_DIR}/librispeech" \ - --full_download="True" - - if [ $? -ne 0 ]; then - echo "Prepare LibriSpeech failed. Terminated." - exit 1 - fi - - for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do - mv data/manifest.${set} data/manifest.${set}.raw - done - - rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw - for set in train-clean-100 train-clean-360 train-other-500; do - cat data/manifest.${set}.raw >> data/manifest.train.raw - done - - for set in dev-clean dev-other; do - cat data/manifest.${set}.raw >> data/manifest.dev.raw - done - - for set in test-clean test-other; do - cat data/manifest.${set}.raw >> data/manifest.test.raw - done -fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for set in train dev test dev-clean dev-other test-clean test-other; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --cmvn_path "data/mean_std.npz" \ - --unit_type ${unit_type} \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${set}.raw" \ - --output_path="data/manifest.${set}" - - if [ $? -ne 0 ]; then - echo "Formt mnaifest.${set} failed. Terminated." - exit 1 - fi - }& - done - wait -fi - -echo "LibriSpeech Data preparation done." -exit 0 - diff --git a/examples/other/1xt2x/librispeech/local/download_lm_en.sh b/examples/other/1xt2x/librispeech/local/download_lm_en.sh deleted file mode 100755 index 390fffc93..000000000 --- a/examples/other/1xt2x/librispeech/local/download_lm_en.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -. 
${MAIN_ROOT}/utils/utility.sh - -DIR=data/lm -mkdir -p ${DIR} - -URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm -MD5="099a601759d467cd0a8523ff939819c5" -TARGET=${DIR}/common_crawl_00.prune01111.trie.klm - -echo "Start downloading the language model. The language model is large, please wait for a moment ..." -download $URL $MD5 $TARGET > /dev/null 2>&1 -if [ $? -ne 0 ]; then - echo "Fail to download the language model!" - exit 1 -else - echo "Download the language model sucessfully" -fi - - -exit 0 diff --git a/examples/other/1xt2x/librispeech/local/download_model.sh b/examples/other/1xt2x/librispeech/local/download_model.sh deleted file mode 100644 index 375d66404..000000000 --- a/examples/other/1xt2x/librispeech/local/download_model.sh +++ /dev/null @@ -1,25 +0,0 @@ -#! /usr/bin/env bash - -if [ $# != 1 ];then - echo "usage: ${0} ckpt_dir" - exit -1 -fi - -ckpt_dir=$1 - -. ${MAIN_ROOT}/utils/utility.sh - -URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz' -MD5=a06d9aadb560ea113984dc98d67232c8 -TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz - - -echo "Download LibriSpeech model ..." -download $URL $MD5 $TARGET -if [ $? -ne 0 ]; then - echo "Fail to download LibriSpeech model!" - exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/librispeech/local/test.sh b/examples/other/1xt2x/librispeech/local/test.sh deleted file mode 100755 index ea40046b1..000000000 --- a/examples/other/1xt2x/librispeech/local/test.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" - exit -1 -fi - -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." - -config_path=$1 -decode_config_path=$2 -ckpt_prefix=$3 -model_type=$4 - -# download language model -bash local/download_lm_en.sh -if [ $? 
-ne 0 ]; then - exit 1 -fi - -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} - -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 -fi - - -exit 0 diff --git a/examples/other/1xt2x/librispeech/path.sh b/examples/other/1xt2x/librispeech/path.sh deleted file mode 100644 index e3696ddd5..000000000 --- a/examples/other/1xt2x/librispeech/path.sh +++ /dev/null @@ -1,16 +0,0 @@ -export MAIN_ROOT=`realpath ${PWD}/../../../../` -export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` - -export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} -export LC_ALL=C - -export PYTHONDONTWRITEBYTECODE=1 -# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C -export PYTHONIOENCODING=UTF-8 -export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} -export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} - -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ - -MODEL=deepspeech2 -export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin diff --git a/examples/other/1xt2x/librispeech/run.sh b/examples/other/1xt2x/librispeech/run.sh deleted file mode 100755 index 8b614bbbf..000000000 --- a/examples/other/1xt2x/librispeech/run.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash -set -e -source path.sh - -stage=0 -stop_stage=100 -conf_path=conf/deepspeech2.yaml -decode_conf_path=conf/tuning/decode.yaml -avg_num=1 -model_type=offline -gpus=1 - -source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; - -v18_ckpt=librispeech_v1.8 -ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') -echo "checkpoint name ${ckpt}" - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # prepare data - mkdir -p exp/${ckpt}/checkpoints - bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1 -fi - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 -fi diff --git a/examples/other/1xt2x/src_deepspeech2x/__init__.py b/examples/other/1xt2x/src_deepspeech2x/__init__.py deleted file mode 100644 index 74be4a254..000000000 --- a/examples/other/1xt2x/src_deepspeech2x/__init__.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any -from typing import List -from typing import Tuple -from typing import Union - -import paddle -from paddle import nn -from paddle.fluid import core -from paddle.nn import functional as F - -from paddlespeech.s2t.utils.log import Log - -#TODO(Hui Zhang): remove fluid import -logger = Log(__name__).getlog() - -########### hack logging ############# -logger.warn = logger.warning - -########### hack paddle ############# -paddle.half = 'float16' -paddle.float = 'float32' -paddle.double = 'float64' -paddle.short = 'int16' -paddle.int = 'int32' -paddle.long = 'int64' -paddle.uint16 = 'uint16' -paddle.cdouble = 'complex128' - - -def convert_dtype_to_string(tensor_dtype): - """ - Convert the data type in numpy to the data type in Paddle - Args: - tensor_dtype(core.VarDesc.VarType): the data type in numpy. - Returns: - core.VarDesc.VarType: the data type in Paddle. - """ - dtype = tensor_dtype - if dtype == core.VarDesc.VarType.FP32: - return paddle.float32 - elif dtype == core.VarDesc.VarType.FP64: - return paddle.float64 - elif dtype == core.VarDesc.VarType.FP16: - return paddle.float16 - elif dtype == core.VarDesc.VarType.INT32: - return paddle.int32 - elif dtype == core.VarDesc.VarType.INT16: - return paddle.int16 - elif dtype == core.VarDesc.VarType.INT64: - return paddle.int64 - elif dtype == core.VarDesc.VarType.BOOL: - return paddle.bool - elif dtype == core.VarDesc.VarType.BF16: - # since there is still no support for bfloat16 in NumPy, - # uint16 is used for casting bfloat16 - return paddle.uint16 - elif dtype == core.VarDesc.VarType.UINT8: - return paddle.uint8 - elif dtype == core.VarDesc.VarType.INT8: - return paddle.int8 - elif dtype == core.VarDesc.VarType.COMPLEX64: - return paddle.complex64 - elif dtype == core.VarDesc.VarType.COMPLEX128: - return paddle.complex128 - else: - raise ValueError("Not supported tensor dtype %s" % dtype) - - -if not hasattr(paddle, 'softmax'): - logger.warn("register user softmax to paddle, remove this when 
fixed!") - setattr(paddle, 'softmax', paddle.nn.functional.softmax) - -if not hasattr(paddle, 'log_softmax'): - logger.warn("register user log_softmax to paddle, remove this when fixed!") - setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) - -if not hasattr(paddle, 'sigmoid'): - logger.warn("register user sigmoid to paddle, remove this when fixed!") - setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) - -if not hasattr(paddle, 'log_sigmoid'): - logger.warn("register user log_sigmoid to paddle, remove this when fixed!") - setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) - -if not hasattr(paddle, 'relu'): - logger.warn("register user relu to paddle, remove this when fixed!") - setattr(paddle, 'relu', paddle.nn.functional.relu) - - -def cat(xs, dim=0): - return paddle.concat(xs, axis=dim) - - -if not hasattr(paddle, 'cat'): - logger.warn( - "override cat of paddle if exists or register, remove this when fixed!") - paddle.cat = cat - - -########### hack paddle.Tensor ############# -def item(x: paddle.Tensor): - return x.numpy().item() - - -if not hasattr(paddle.Tensor, 'item'): - logger.warn( - "override item of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.item = item - - -def func_long(x: paddle.Tensor): - return paddle.cast(x, paddle.long) - - -if not hasattr(paddle.Tensor, 'long'): - logger.warn( - "override long of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.long = func_long - -if not hasattr(paddle.Tensor, 'numel'): - logger.warn( - "override numel of paddle.Tensor if exists or register, remove this when fixed!" 
- ) - paddle.Tensor.numel = paddle.numel - - -def new_full(x: paddle.Tensor, - size: Union[List[int], Tuple[int], paddle.Tensor], - fill_value: Union[float, int, bool, paddle.Tensor], - dtype=None): - return paddle.full(size, fill_value, dtype=x.dtype) - - -if not hasattr(paddle.Tensor, 'new_full'): - logger.warn( - "override new_full of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.new_full = new_full - - -def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: - if convert_dtype_to_string(xs.dtype) == paddle.bool: - xs = xs.astype(paddle.int) - return xs.equal( - paddle.to_tensor( - ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)) - - -if not hasattr(paddle.Tensor, 'eq'): - logger.warn( - "override eq of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.eq = eq - -if not hasattr(paddle, 'eq'): - logger.warn( - "override eq of paddle if exists or register, remove this when fixed!") - paddle.eq = eq - - -def contiguous(xs: paddle.Tensor) -> paddle.Tensor: - return xs - - -if not hasattr(paddle.Tensor, 'contiguous'): - logger.warn( - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.contiguous = contiguous - - -def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: - nargs = len(args) - assert (nargs <= 1) - s = paddle.shape(xs) - if nargs == 1: - return s[args[0]] - else: - return s - - -#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.warn( - "override size of paddle.Tensor " - "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" 
-) -paddle.Tensor.size = size - - -def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: - return xs.reshape(args) - - -if not hasattr(paddle.Tensor, 'view'): - logger.warn("register user view to paddle.Tensor, remove this when fixed!") - paddle.Tensor.view = view - - -def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: - return xs.reshape(ys.size()) - - -if not hasattr(paddle.Tensor, 'view_as'): - logger.warn( - "register user view_as to paddle.Tensor, remove this when fixed!") - paddle.Tensor.view_as = view_as - - -def is_broadcastable(shp1, shp2): - for a, b in zip(shp1[::-1], shp2[::-1]): - if a == 1 or b == 1 or a == b: - pass - else: - return False - return True - - -def masked_fill(xs: paddle.Tensor, - mask: paddle.Tensor, - value: Union[float, int]): - assert is_broadcastable(xs.shape, mask.shape) is True - bshape = paddle.broadcast_shape(xs.shape, mask.shape) - mask = mask.broadcast_to(bshape) - trues = paddle.ones_like(xs) * value - xs = paddle.where(mask, trues, xs) - return xs - - -if not hasattr(paddle.Tensor, 'masked_fill'): - logger.warn( - "register user masked_fill to paddle.Tensor, remove this when fixed!") - paddle.Tensor.masked_fill = masked_fill - - -def masked_fill_(xs: paddle.Tensor, - mask: paddle.Tensor, - value: Union[float, int]) -> paddle.Tensor: - assert is_broadcastable(xs.shape, mask.shape) is True - bshape = paddle.broadcast_shape(xs.shape, mask.shape) - mask = mask.broadcast_to(bshape) - trues = paddle.ones_like(xs) * value - ret = paddle.where(mask, trues, xs) - paddle.assign(ret.detach(), output=xs) - return xs - - -if not hasattr(paddle.Tensor, 'masked_fill_'): - logger.warn( - "register user masked_fill_ to paddle.Tensor, remove this when fixed!") - paddle.Tensor.masked_fill_ = masked_fill_ - - -def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: - val = paddle.full_like(xs, value) - paddle.assign(val.detach(), output=xs) - return xs - - -if not hasattr(paddle.Tensor, 'fill_'): - 
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") - paddle.Tensor.fill_ = fill_ - - -def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: - return paddle.tile(xs, size) - - -if not hasattr(paddle.Tensor, 'repeat'): - logger.warn( - "register user repeat to paddle.Tensor, remove this when fixed!") - paddle.Tensor.repeat = repeat - -if not hasattr(paddle.Tensor, 'softmax'): - logger.warn( - "register user softmax to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) - -if not hasattr(paddle.Tensor, 'sigmoid'): - logger.warn( - "register user sigmoid to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) - -if not hasattr(paddle.Tensor, 'relu'): - logger.warn("register user relu to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) - - -def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: - return x.astype(other.dtype) - - -if not hasattr(paddle.Tensor, 'type_as'): - logger.warn( - "register user type_as to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'type_as', type_as) - - -def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: - assert len(args) == 1 - if isinstance(args[0], str): # dtype - return x.astype(args[0]) - elif isinstance(args[0], paddle.Tensor): #Tensor - return x.astype(args[0].dtype) - else: # Device - return x - - -if not hasattr(paddle.Tensor, 'to'): - logger.warn("register user to to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'to', to) - - -def func_float(x: paddle.Tensor) -> paddle.Tensor: - return x.astype(paddle.float) - - -if not hasattr(paddle.Tensor, 'float'): - logger.warn("register user float to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'float', func_float) - - -def func_int(x: paddle.Tensor) -> paddle.Tensor: - return x.astype(paddle.int) - - -if not 
hasattr(paddle.Tensor, 'int'): - logger.warn("register user int to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'int', func_int) - - -def tolist(x: paddle.Tensor) -> List[Any]: - return x.numpy().tolist() - - -if not hasattr(paddle.Tensor, 'tolist'): - logger.warn( - "register user tolist to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'tolist', tolist) - - -########### hack paddle.nn ############# -class GLU(nn.Layer): - """Gated Linear Units (GLU) Layer""" - - def __init__(self, dim: int=-1): - super().__init__() - self.dim = dim - - def forward(self, xs): - return F.glu(xs, axis=self.dim) - - -if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'GLU', GLU) diff --git a/examples/other/1xt2x/src_deepspeech2x/bin/test.py b/examples/other/1xt2x/src_deepspeech2x/bin/test.py deleted file mode 100644 index 88a13fdca..000000000 --- a/examples/other/1xt2x/src_deepspeech2x/bin/test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Evaluation for DeepSpeech2 model.""" -from src_deepspeech2x.test_model import DeepSpeech2Tester as Tester -from yacs.config import CfgNode - -from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments - - -def main_sp(config, args): - exp = Tester(config, args) - exp.setup() - exp.run_test() - - -def main(config, args): - main_sp(config, args) - - -if __name__ == "__main__": - parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") - args = parser.parse_args() - print_arguments(args, globals()) - print("model_type:{}".format(args.model_type)) - - # https://yaml.org/type/float.html - config = CfgNode(new_allowed=True) - if args.config: - config.merge_from_file(args.config) - if args.decode_cfg: - decode_confs = CfgNode(new_allowed=True) - decode_confs.merge_from_file(args.decode_cfg) - config.decode = decode_confs - if args.opts: - config.merge_from_list(args.opts) - config.freeze() - print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - - main(config, args) diff --git a/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py b/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py deleted file mode 100644 index f6e185ff1..000000000 --- a/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Deepspeech2 ASR Model""" -import paddle -from paddle import nn -from src_deepspeech2x.models.ds2.rnn import RNNStack - -from paddlespeech.s2t.models.ds2.conv import ConvStack -from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.utils import layer_tools -from paddlespeech.s2t.utils.checkpoint import Checkpoint -from paddlespeech.s2t.utils.log import Log -logger = Log(__name__).getlog() - -__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] - - -class CRNNEncoder(nn.Layer): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True): - super().__init__() - self.rnn_size = rnn_size - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - - @property - def output_size(self): - return self.rnn_size * 2 - - def forward(self, audio, audio_len): - """Compute Encoder outputs - - Args: - audio (Tensor): [B, Tmax, D] - text (Tensor): [B, Umax] - audio_len (Tensor): [B] - text_len (Tensor): [B] - Returns: - x (Tensor): encoder outputs, [B, T, D] - x_lens (Tensor): encoder length, [B] - """ - # [B, T, D] -> [B, D, T] - audio = audio.transpose([0, 2, 1]) - # [B, D, T] -> [B, C=1, D, T] - x = audio.unsqueeze(1) - x_lens = audio_len - - # convolution group - x, x_lens = self.conv(x, 
x_lens) - x_val = x.numpy() - - # convert data from convolution feature map to sequence of vectors - #B, C, D, T = paddle.shape(x) # not work under jit - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit - x = x.reshape([0, 0, -1]) #[B, T, C*D] - - # remove padding part - x, x_lens = self.rnn(x, x_lens) #[B, T, D] - return x, x_lens - - -class DeepSpeech2Model(nn.Layer): - """The DeepSpeech2 network structure. - - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. 
- :rtype: tuple of LayerOutput - """ - - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - blank_id=0): - super().__init__() - self.encoder = CRNNEncoder( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - assert (self.encoder.output_size == rnn_size * 2) - - self.decoder = CTCDecoder( - odim=dict_size, # is in vocab - enc_n_units=self.encoder.output_size, - blank_id=blank_id, # first token is - dropout_rate=0.0, - reduction=True, # sum - batch_average=True) # sum / batch_size - - def forward(self, audio, audio_len, text, text_len): - """Compute Model loss - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - text (Tensor): [B, U] - text_len (Tensor): [B] - - Returns: - loss (Tensor): [1] - """ - eouts, eouts_len = self.encoder(audio, audio_len) - loss = self.decoder(eouts, eouts_len, text, text_len) - return loss - - @paddle.no_grad() - def decode(self, audio, audio_len): - # decoders only accept string encoded in utf-8 - - # Make sure the decoder has been initialized - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - batch_size = probs.shape[0] - self.decoder.reset_decoder(batch_size=batch_size) - self.decoder.next(probs, eouts_len) - trans_best, trans_beam = self.decoder.decode() - return trans_best - - @classmethod - def from_pretrained(cls, dataloader, config, checkpoint_path): - """Build a DeepSpeech2Model model from a pretrained model. - Parameters - ---------- - dataloader: paddle.io.DataLoader - - config: yacs.config.CfgNode - model configs - - checkpoint_path: Path or str - the path of pretrained model checkpoint, without extension name - - Returns - ------- - DeepSpeech2Model - The model built from pretrained result. 
- """ - model = cls(feat_size=dataloader.collate_fn.feature_size, - dict_size=len(dataloader.collate_fn.vocab_list), - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights) - infos = Checkpoint().load_parameters( - model, checkpoint_path=checkpoint_path) - logger.info(f"checkpoint info: {infos}") - layer_tools.summary(model) - return model - - @classmethod - def from_config(cls, config): - """Build a DeepSpeec2Model from config - Parameters - - config: yacs.config.CfgNode - config - Returns - ------- - DeepSpeech2Model - The model built from config. - """ - model = cls(feat_size=config.feat_size, - dict_size=config.dict_size, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights, - blank_id=config.blank_id) - return model - - -class DeepSpeech2InferModel(DeepSpeech2Model): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - blank_id=0): - super().__init__( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights, - blank_id=blank_id) - - def forward(self, audio, audio_len): - """export model function - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - - Returns: - probs: probs after softmax - """ - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - return probs, eouts_len - - def export(self): - static_model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, self.encoder.feat_size], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # 
audio_length, [B] - ]) - return static_model diff --git a/examples/other/1xt2x/src_deepspeech2x/models/ds2/rnn.py b/examples/other/1xt2x/src_deepspeech2x/models/ds2/rnn.py deleted file mode 100644 index 383a07467..000000000 --- a/examples/other/1xt2x/src_deepspeech2x/models/ds2/rnn.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -from paddlespeech.s2t.modules.activation import brelu -from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.utils.log import Log -logger = Log(__name__).getlog() - -__all__ = ['RNNStack'] - - -class RNNCell(nn.RNNCellBase): - r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it - computes the outputs and updates states. - The formula used is as follows: - .. math:: - h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) - y_{t} & = h_{t} - - where :math:`act` is for :attr:`activation`. 
- """ - - def __init__(self, - hidden_size: int, - activation="tanh", - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - if activation not in ["tanh", "relu", "brelu"]: - raise ValueError( - "activation for SimpleRNNCell should be tanh or relu, " - "but get {}".format(activation)) - self.activation = activation - self._activation_fn = paddle.tanh \ - if activation == "tanh" \ - else F.relu - if activation == 'brelu': - self._activation_fn = brelu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - pre_h = states - i2h = inputs - if self.bias_ih is not None: - i2h += self.bias_ih - h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h2h += self.bias_hh - h = self._activation_fn(i2h + h2h) - return h, h - - @property - def state_shape(self): - return (self.hidden_size, ) - - -class GRUCell(nn.RNNCellBase): - r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, - it computes the outputs and updates states. - The formula for GRU used is as follows: - .. math:: - r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) - z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) - \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) - h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} - y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise - multiplication operator. 
- """ - - def __init__(self, - input_size: int, - hidden_size: int, - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (3 * hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - self.input_size = input_size - self._gate_activation = F.sigmoid - self._activation = paddle.relu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - - pre_hidden = states # shape [batch_size, hidden_size] - - x_gates = inputs - if self.bias_ih is not None: - x_gates = x_gates + self.bias_ih - bias_u, bias_r, bias_c = paddle.split( - self.bias_hh, num_or_sections=3, axis=0) - - weight_hh = paddle.transpose( - self.weight_hh, - perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size] - w_u_r_c = paddle.flatten(weight_hh) - size_u_r = self.hidden_size * 2 * self.hidden_size - w_u_r = paddle.reshape(w_u_r_c[:size_u_r], - (self.hidden_size, self.hidden_size * 2)) - w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1) - w_c = paddle.reshape(w_u_r_c[size_u_r:], - (self.hidden_size, self.hidden_size)) - - h_u = paddle.matmul( - pre_hidden, w_u, - transpose_y=False) + bias_u #shape [batch_size, hidden_size] - h_r = paddle.matmul( - pre_hidden, w_r, - transpose_y=False) + bias_r #shape [batch_size, hidden_size] - - x_u, x_r, x_c = paddle.split( - x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size] - - u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size] - r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size] - c = self._activation( - x_c + paddle.matmul(r * pre_hidden, w_c, 
transpose_y=False) + - bias_c) # [batch_size, hidden_size] - - h = (1 - u) * pre_hidden + u * c - # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru - return h, h - - @property - def state_shape(self): - r""" - The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch - size would be automatically inserted into shape). The shape corresponds - to the shape of :math:`h_{t-1}`. - """ - return (self.hidden_size, ) - - -class BiRNNWithBN(nn.Layer): - """Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param size: Dimension of RNN cells. - :type size: int - :param share_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - :type share_weights: bool - :return: Bidirectional simple rnn layer. - :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int, share_weights: bool): - super().__init__() - self.share_weights = share_weights - if self.share_weights: - #input-hidden weights shared between bi-directional rnn. 
- self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - # batch norm is only performed on input-state projection - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = self.fw_fc - self.bw_bn = self.fw_bn - else: - self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - - self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class BiGRUWithBN(nn.Layer): - """Bidirectonal gru layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer. - :type name: string - :param input: Input layer. - :type input: Variable - :param size: Dimension of GRU cells. - :type size: int - :param act: Activation type. - :type act: string - :return: Bidirectional GRU layer. 
- :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int): - super().__init__() - hidden_size = h_size * 3 - - self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - - self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class RNNStack(nn.Layer): - """RNN group with stacked bidirectional simple RNN or GRU layers. - - :param input: Input layer. - :type input: Variable - :param size: Dimension of RNN cells in each layer. - :type size: int - :param num_stacks: Number of stacked rnn layers. - :type num_stacks: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: Output layer of the RNN group. 
- :rtype: Variable - """ - - def __init__(self, - i_size: int, - h_size: int, - num_stacks: int, - use_gru: bool, - share_rnn_weights: bool): - super().__init__() - rnn_stacks = [] - for i in range(num_stacks): - if use_gru: - #default:GRU using tanh - rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) - else: - rnn_stacks.append( - BiRNNWithBN( - i_size=i_size, - h_size=h_size, - share_weights=share_rnn_weights)) - i_size = h_size * 2 - - self.rnn_stacks = nn.LayerList(rnn_stacks) - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - """ - x: shape [B, T, D] - x_len: shpae [B] - """ - for i, rnn in enumerate(self.rnn_stacks): - x, x_len = rnn(x, x_len) - masks = make_non_pad_mask(x_len) #[B, T] - masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) - return x, x_len diff --git a/examples/other/1xt2x/src_deepspeech2x/test_model.py b/examples/other/1xt2x/src_deepspeech2x/test_model.py deleted file mode 100644 index 11b85442d..000000000 --- a/examples/other/1xt2x/src_deepspeech2x/test_model.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Contains DeepSpeech2 and DeepSpeech2Online model.""" -import time -from collections import defaultdict -from contextlib import nullcontext - -import numpy as np -import paddle -from paddle import distributed as dist -from paddle.io import DataLoader -from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel -from src_deepspeech2x.models.ds2 import DeepSpeech2Model - -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator -from paddlespeech.s2t.io.dataset import ManifestDataset -from paddlespeech.s2t.io.sampler import SortagradBatchSampler -from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler -from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline -from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog -from paddlespeech.s2t.training.trainer import Trainer -from paddlespeech.s2t.utils import error_rate -from paddlespeech.s2t.utils import layer_tools -from paddlespeech.s2t.utils import mp_tools -from paddlespeech.s2t.utils.log import Log - -logger = Log(__name__).getlog() - - -class DeepSpeech2Trainer(Trainer): - def __init__(self, config, args): - super().__init__(config, args) - - def train_batch(self, batch_index, batch_data, msg): - train_conf = self.config - start = time.time() - - # forward - utt, audio, audio_len, text, text_len = batch_data - loss = self.model(audio, audio_len, text, text_len) - losses_np = { - 'train_loss': float(loss), - } - - # loss backward - if (batch_index + 1) % train_conf.accum_grad != 0: - # Disable gradient synchronizations across DDP processes. - # Within this context, gradients will be accumulated on module - # variables, which will later be synchronized. - context = self.model.no_sync - else: - # Used for single gpu training and DDP gradient synchronization - # processes. 
- context = nullcontext - - with context(): - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - - # optimizer step - if (batch_index + 1) % train_conf.accum_grad == 0: - self.optimizer.step() - self.optimizer.clear_grad() - self.iteration += 1 - - iteration_time = time.time() - start - - msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.batch_size) - msg += "accum: {}, ".format(train_conf.accum_grad) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - logger.info(msg) - - if dist.get_rank() == 0 and self.visualizer: - for k, v in losses_np.items(): - # `step -1` since we update `step` after optimizer.step(). - self.visualizer.add_scalar("train/{}".format(k), v, - self.iteration - 1) - - @paddle.no_grad() - def valid(self): - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") - self.model.eval() - valid_losses = defaultdict(list) - num_seen_utts = 1 - total_loss = 0.0 - for i, batch in enumerate(self.valid_loader): - utt, audio, audio_len, text, text_len = batch - loss = self.model(audio, audio_len, text, text_len) - if paddle.isfinite(loss): - num_utts = batch[1].shape[0] - num_seen_utts += num_utts - total_loss += float(loss) * num_utts - valid_losses['val_loss'].append(float(loss)) - - if (i + 1) % self.config.log_interval == 0: - valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} - valid_dump['val_history_loss'] = total_loss / num_seen_utts - - # logging - msg = f"Valid: Rank: {dist.get_rank()}, " - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader)) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in valid_dump.items()) - logger.info(msg) - - logger.info('Rank {} Val info val_loss {}'.format( - dist.get_rank(), total_loss / num_seen_utts)) - return total_loss, num_seen_utts - - def setup_model(self): - config = 
self.config.clone() - config.defrost() - config.feat_size = self.train_loader.collate_fn.feature_size - #config.dict_size = self.train_loader.collate_fn.vocab_size - config.dict_size = len(self.train_loader.collate_fn.vocab_list) - config.freeze() - - if self.args.model_type == 'offline': - model = DeepSpeech2Model.from_config(config) - elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline.from_config(config) - else: - raise Exception("wrong model type") - if self.parallel: - model = paddle.DataParallel(model) - - logger.info(f"{model}") - layer_tools.print_params(model, logger.info) - - grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip) - lr_scheduler = paddle.optimizer.lr.ExponentialDecay( - learning_rate=config.lr, gamma=config.lr_decay, verbose=True) - optimizer = paddle.optimizer.Adam( - learning_rate=lr_scheduler, - parameters=model.parameters(), - weight_decay=paddle.regularizer.L2Decay(config.weight_decay), - grad_clip=grad_clip) - - self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - logger.info("Setup model/optimizer/lr_scheduler!") - - def setup_dataloader(self): - config = self.config.clone() - config.defrost() - config.keep_transcription_text = False - - config.manifest = config.train_manifest - train_dataset = ManifestDataset.from_config(config) - - config.manifest = config.dev_manifest - dev_dataset = ManifestDataset.from_config(config) - - config.manifest = config.test_manifest - test_dataset = ManifestDataset.from_config(config) - - if self.parallel: - batch_sampler = SortagradDistributedBatchSampler( - train_dataset, - batch_size=config.batch_size, - num_replicas=None, - rank=None, - shuffle=True, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - else: - batch_sampler = SortagradBatchSampler( - train_dataset, - shuffle=True, - batch_size=config.batch_size, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - 
- collate_fn_train = SpeechCollator.from_config(config) - - config.augmentation_config = "" - collate_fn_dev = SpeechCollator.from_config(config) - - config.keep_transcription_text = True - config.augmentation_config = "" - collate_fn_test = SpeechCollator.from_config(config) - - self.train_loader = DataLoader( - train_dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn_train, - num_workers=config.num_workers) - self.valid_loader = DataLoader( - dev_dataset, - batch_size=config.batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_dev) - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decode.decode_batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_test) - if "" in self.test_loader.collate_fn.vocab_list: - self.test_loader.collate_fn.vocab_list.remove("") - if "" in self.valid_loader.collate_fn.vocab_list: - self.valid_loader.collate_fn.vocab_list.remove("") - if "" in self.train_loader.collate_fn.vocab_list: - self.train_loader.collate_fn.vocab_list.remove("") - logger.info("Setup train/valid/test Dataloader!") - - -class DeepSpeech2Tester(DeepSpeech2Trainer): - def __init__(self, config, args): - - self._text_featurizer = TextFeaturizer( - unit_type=config.unit_type, vocab=None) - super().__init__(config, args) - - def ordid2token(self, texts, texts_len): - """ ord() id to chr() chr """ - trans = [] - for text, n in zip(texts, texts_len): - n = n.numpy().item() - ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) - return trans - - def compute_metrics(self, - utts, - audio, - audio_len, - texts, - texts_len, - fout=None): - cfg = self.config.decode - errors_sum, len_refs, num_ins = 0.0, 0, 0 - errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors - error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - - target_transcripts = self.ordid2token(texts, texts_len) - - result_transcripts = 
self.compute_result_transcripts(audio, audio_len) - - for utt, target, result in zip(utts, target_transcripts, - result_transcripts): - errors, len_ref = errors_func(target, result) - errors_sum += errors - len_refs += len_ref - num_ins += 1 - if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) - logger.info("Current error rate [%s] = %f" % - (cfg.error_rate_type, error_rate_func(target, result))) - - return dict( - errors_sum=errors_sum, - len_refs=len_refs, - num_ins=num_ins, - error_rate=errors_sum / len_refs, - error_rate_type=cfg.error_rate_type) - - def compute_result_transcripts(self, audio, audio_len): - result_transcripts = self.model.decode(audio, audio_len) - - result_transcripts = [ - self._text_featurizer.detokenize(item) - for item in result_transcripts - ] - return result_transcripts - - @mp_tools.rank_zero_only - @paddle.no_grad() - def test(self): - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - self.model.eval() - cfg = self.config - error_rate_type = None - errors_sum, len_refs, num_ins = 0.0, 0, 0 - - # Initialized the decoder in model - decode_cfg = self.config.decode - vocab_list = self.test_loader.collate_fn.vocab_list - decode_batch_size = self.test_loader.batch_size - self.model.decoder.init_decoder( - decode_batch_size, vocab_list, decode_cfg.decoding_method, - decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, - decode_cfg.beam_size, decode_cfg.cutoff_prob, - decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch) - - with open(self.args.result_file, 'w') as fout: - for i, batch in enumerate(self.test_loader): - utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, - texts_len, fout) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - logger.info("Error rate 
[%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) - - # logging - msg = "Test: " - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "Final error rate [%s] (%d/%d) = %f" % ( - error_rate_type, num_ins, num_ins, errors_sum / len_refs) - logger.info(msg) - self.model.decoder.del_decoder() - - def run_test(self): - self.resume_or_scratch() - try: - self.test() - except KeyboardInterrupt: - exit(-1) - - def export(self): - if self.args.model_type == 'offline': - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - elif self.args.model_type == 'online': - infer_model = DeepSpeech2InferModelOnline.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - else: - raise Exception("wrong model type") - - infer_model.eval() - feat_dim = self.test_loader.collate_fn.feature_size - static_model = infer_model.export() - logger.info(f"Export code: {static_model.forward.code}") - paddle.jit.save(static_model, self.args.export_path) - - def run_export(self): - try: - self.export() - except KeyboardInterrupt: - exit(-1) diff --git a/examples/other/mfa/local/reorganize_baker.py b/examples/other/mfa/local/reorganize_baker.py index 8adad834f..153e01d13 100644 --- a/examples/other/mfa/local/reorganize_baker.py +++ b/examples/other/mfa/local/reorganize_baker.py @@ -42,9 +42,6 @@ def get_transcripts(path: Union[str, Path]): for i in range(0, len(lines), 2): sentence_id = lines[i].split()[0] transcription = lines[i + 1].strip() - # tones are dropped here - # since the lexicon does not consider tones, too - transcription = " ".join([item[:-1] for item in transcription.split()]) transcripts[sentence_id] = transcription return transcripts diff --git a/examples/other/mfa/run.sh b/examples/other/mfa/run.sh old mode 100644 new mode 100755 index 1fef58b4e..29dacc9b1 --- a/examples/other/mfa/run.sh +++ b/examples/other/mfa/run.sh @@ -4,7 +4,7 @@ 
mkdir -p $EXP_DIR LEXICON_NAME='simple' if [ ! -f "$EXP_DIR/$LEXICON_NAME.lexicon" ]; then echo "generating lexicon..." - python local/generate_lexicon.py "$EXP_DIR/$LEXICON_NAME" --with-r + python local/generate_lexicon.py "$EXP_DIR/$LEXICON_NAME" --with-r --with-tone echo "lexicon done" fi @@ -16,6 +16,7 @@ if [ ! -d $EXP_DIR/baker_corpus ]; then echo "transcription for each audio file is saved with the same namd in $EXP_DIR/baker_corpus " fi + echo "detecting oov..." python local/detect_oov.py $EXP_DIR/baker_corpus $EXP_DIR/"$LEXICON_NAME.lexicon" echo "detecting oov done. you may consider regenerate lexicon if there is unexpected OOVs." @@ -44,6 +45,3 @@ if [ ! -d "$EXP_DIR/baker_alignment" ]; then echo "model: $EXP_DIR/baker_model" fi - - - diff --git a/examples/ted_en_zh/st0/local/train.sh b/examples/ted_en_zh/st0/local/train.sh index ad00653b7..71659e28d 100755 --- a/examples/ted_en_zh/st0/local/train.sh +++ b/examples/ted_en_zh/st0/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi @@ -10,6 +10,13 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 +ips=$3 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -26,7 +33,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ diff --git a/examples/ted_en_zh/st0/run.sh b/examples/ted_en_zh/st0/run.sh index 1746c0251..c5a59f657 100755 --- a/examples/ted_en_zh/st0/run.sh +++ b/examples/ted_en_zh/st0/run.sh @@ -6,6 +6,7 @@ gpus=0,1,2,3 stage=0 stop_stage=50 conf_path=conf/transformer_mtl_noam.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=5 data_path=./TED_EnZh # path to unzipped data @@ -23,7 +24,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/ted_en_zh/st1/local/train.sh b/examples/ted_en_zh/st1/local/train.sh index 5da64e99c..3e9295e53 100755 --- a/examples/ted_en_zh/st1/local/train.sh +++ b/examples/ted_en_zh/st1/local/train.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path" +if [ $# -lt 3 ] && [ $# -gt 4 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path ips(optional)" exit -1 fi @@ -11,6 +11,15 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 ckpt_path=$3 +ips=$4 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi + +mkdir -p exp mkdir -p exp @@ -28,7 +37,7 @@ python3 -u ${BIN_DIR}/train.py \ --checkpoint_path "${ckpt_path}" \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ diff --git a/examples/ted_en_zh/st1/run.sh b/examples/ted_en_zh/st1/run.sh index 1808e37b4..06a407d44 100755 --- a/examples/ted_en_zh/st1/run.sh +++ b/examples/ted_en_zh/st1/run.sh @@ -7,6 +7,7 @@ gpus=0,1,2,3 stage=1 stop_stage=4 conf_path=conf/transformer_mtl_noam.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml ckpt_path= # paddle.98 # (finetune from FAT-ST pretrained model) avg_num=5 @@ -29,7 +30,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "Finetune from Pretrained Model" ${ckpt_path} ./local/download_pretrain.sh || exit -1 fi - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/tiny/asr0/conf/augmentation.json b/examples/tiny/asr0/conf/augmentation.json deleted file mode 100644 index 4480307b9..000000000 --- a/examples/tiny/asr0/conf/augmentation.json +++ /dev/null @@ -1,36 +0,0 @@ -[ - { - "type": "speed", - "params": { - "min_speed_rate": 0.9, - "max_speed_rate": 1.1, - "num_rates": 3 - }, - "prob": 0.0 - }, - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - }, - { - "type": "specaug", - "params": { - "W": 5, - "warp_mode": "PIL", - "F": 30, - "n_freq_masks": 2, - "T": 40, - "n_time_masks": 2, - "p": 1.0, - "adaptive_number_ratio": 0, - "adaptive_size_ratio": 0, - "max_n_time_masks": 20, - 
"replace_with_zero": true - }, - "prob": 1.0 - } -] diff --git a/examples/tiny/asr0/conf/deepspeech2.yaml b/examples/tiny/asr0/conf/deepspeech2.yaml index 64d432e26..a94143b95 100644 --- a/examples/tiny/asr0/conf/deepspeech2.yaml +++ b/examples/tiny/asr0/conf/deepspeech2.yaml @@ -16,28 +16,26 @@ max_output_input_ratio: 10.0 ########################################### # Dataloader # ########################################### -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 2 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs batch_size: 4 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 ############################################ # Network Architecture # @@ -45,8 +43,10 @@ batch_size: 4 num_conv_layers: 2 num_rnn_layers: 3 rnn_layer_size: 2048 +rnn_direction: bidirect # [forward, bidirect] +num_fc_layers: 0 +fc_layers_size_list: -1, use_gru: False -share_rnn_weights: True blank_id: 0 @@ -59,6 +59,7 @@ lr: 1.0e-5 lr_decay: 0.8 weight_decay: 1.0e-6 global_grad_clip: 5.0 +dist_sampler: False log_interval: 1 checkpoint: kbest_n: 3 diff --git 
a/examples/tiny/asr0/conf/deepspeech2_online.yaml b/examples/tiny/asr0/conf/deepspeech2_online.yaml index 74a4dc814..1bd8da19c 100644 --- a/examples/tiny/asr0/conf/deepspeech2_online.yaml +++ b/examples/tiny/asr0/conf/deepspeech2_online.yaml @@ -16,29 +16,27 @@ max_output_input_ratio: 10.0 ########################################### # Dataloader # ########################################### -mean_std_filepath: data/mean_std.json -unit_type: char -vocab_filepath: data/lang_char/vocab.txt -augmentation_config: conf/augmentation.json -random_seed: 0 -spm_model_prefix: -spectrum_type: linear +vocab_filepath: data/lang_char/vocab.txt +spm_model_prefix: '' +unit_type: 'char' +preprocess_config: conf/preprocess.yaml feat_dim: 161 -delta_delta: False stride_ms: 10.0 -window_ms: 20.0 -n_fft: None -max_freq: None -target_sample_rate: 16000 -use_dB_normalization: True -target_dB: -20 -dither: 1.0 -keep_transcription_text: False -sortagrad: True -shuffle_method: batch_shuffle -num_workers: 0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs batch_size: 4 - +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 8 +subsampling_factor: 1 +num_encs: 1 + ############################################ # Network Architecture # ############################################ @@ -61,6 +59,7 @@ lr: 1.0e-5 lr_decay: 1.0 weight_decay: 1.0e-6 global_grad_clip: 5.0 +dist_sampler: False log_interval: 1 checkpoint: kbest_n: 3 diff --git a/examples/tiny/asr0/conf/preprocess.yaml b/examples/tiny/asr0/conf/preprocess.yaml new file mode 100644 index 000000000..3f526e0ad --- /dev/null +++ b/examples/tiny/asr0/conf/preprocess.yaml @@ -0,0 +1,25 @@ +process: + 
# extract kaldi fbank from PCM + - type: fbank_kaldi + fs: 16000 + n_mels: 161 + n_shift: 160 + win_length: 400 + dither: 0.1 + - type: cmvn_json + cmvn_path: data/mean_std.json + # these three processes are a.k.a. SpecAugument + - type: time_warp + max_time_warp: 5 + inplace: true + mode: PIL + - type: freq_mask + F: 30 + n_mask: 2 + inplace: true + replace_with_zero: false + - type: time_mask + T: 40 + n_mask: 2 + inplace: true + replace_with_zero: false diff --git a/examples/tiny/asr0/local/export.sh b/examples/tiny/asr0/local/export.sh index 426a72fe5..ce7e6d642 100755 --- a/examples/tiny/asr0/local/export.sh +++ b/examples/tiny/asr0/local/export.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" exit -1 fi @@ -11,14 +11,12 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -model_type=$4 python3 -u ${BIN_DIR}/export.py \ --ngpu ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} \ ---model_type ${model_type} +--export_path ${jit_model_export_path} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/tiny/asr0/local/test.sh b/examples/tiny/asr0/local/test.sh index ea40046b1..55f97d2ec 100755 --- a/examples/tiny/asr0/local/test.sh +++ b/examples/tiny/asr0/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 4 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi @@ -11,7 +11,6 @@ echo "using $ngpu gpus..." 
config_path=$1 decode_config_path=$2 ckpt_prefix=$3 -model_type=$4 # download language model bash local/download_lm_en.sh @@ -24,8 +23,7 @@ python3 -u ${BIN_DIR}/test.py \ --config ${config_path} \ --decode_cfg ${decode_config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} +--checkpoint_path ${ckpt_prefix} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/tiny/asr0/local/train.sh b/examples/tiny/asr0/local/train.sh index 9060be674..8b67902fe 100755 --- a/examples/tiny/asr0/local/train.sh +++ b/examples/tiny/asr0/local/train.sh @@ -15,14 +15,20 @@ if [ ${seed} != 0 ]; then echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." fi -if [ $# != 3 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi config_path=$1 ckpt_name=$2 -model_type=$3 +ips=$3 + +if [ ! 
$ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -31,15 +37,13 @@ python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --profiler-options "${profiler_options}" \ --seed ${seed} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} \ --profiler-options "${profiler_options}" \ --seed ${seed} fi diff --git a/examples/tiny/asr0/run.sh b/examples/tiny/asr0/run.sh index 25f046245..3e84d4224 100755 --- a/examples/tiny/asr0/run.sh +++ b/examples/tiny/asr0/run.sh @@ -2,14 +2,13 @@ set -e source path.sh -gpus=0 +gpus=4 stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=1 -model_type=offline - source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} @@ -23,7 +22,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -33,10 +32,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 + CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}|| exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} 
exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} + CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/tiny/asr1/local/train.sh b/examples/tiny/asr1/local/train.sh index 5617f7efe..459f2e218 100755 --- a/examples/tiny/asr1/local/train.sh +++ b/examples/tiny/asr1/local/train.sh @@ -17,13 +17,20 @@ if [ ${seed} != 0 ]; then echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." fi -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" +if [ $# -lt 2 ] && [ $# -gt 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi config_path=$1 ckpt_name=$2 +ips=$3 + +if [ ! $ips ];then + ips_config= +else + ips_config="--ips="${ips} +fi mkdir -p exp @@ -37,7 +44,7 @@ python3 -u ${BIN_DIR}/train.py \ --benchmark-batch-size ${benchmark_batch_size} \ --benchmark-max-step ${benchmark_max_step} else -python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \ +python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --seed ${seed} \ --config ${config_path} \ diff --git a/examples/tiny/asr1/run.sh b/examples/tiny/asr1/run.sh index 1651c034c..ca0a7a013 100755 --- a/examples/tiny/asr1/run.sh +++ b/examples/tiny/asr1/run.sh @@ -2,10 +2,11 @@ set -e source path.sh -gpus=0 +gpus=4 stage=0 stop_stage=50 conf_path=conf/transformer.yaml +ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=1 @@ -22,7 +23,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md 
index f373ca6a3..0b0ce0934 100644 --- a/examples/vctk/tts3/README.md +++ b/examples/vctk/tts3/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ## Dataset ### Download and Extract the dataset -Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handle/10283/3443). +Download VCTK-0.92 from it's [Official Website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2. @@ -112,12 +112,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ``` ```text usage: synthesize.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--ngpu NGPU] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] @@ -126,11 +126,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am 
{speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -142,10 +141,10 @@ optional arguments: speaker id map file. --voice-cloning VOICE_CLONING whether training voice cloning model. - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. 
@@ -161,12 +160,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ``` ```text usage: synthesize_e2e.py [-h] - [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] + [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] - [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] + [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_stat VOC_STAT] [--lang LANG] [--inference_dir INFERENCE_DIR] [--ngpu NGPU] @@ -176,11 +175,10 @@ Synthesize with acoustic model & vocoder optional arguments: -h, --help show this help message and exit - --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} + --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech} Choose acoustic model type of tts task. --am_config AM_CONFIG - Config of acoustic model. Use deault config when it is - None. + Config of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_stat AM_STAT mean and standard deviation used to normalize spectrogram when training acoustic model. @@ -191,10 +189,10 @@ optional arguments: --speaker_dict SPEAKER_DICT speaker id map file. 
--spk_id SPK_ID spk id for multi speaker acoustic model - --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} + --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc} Choose vocoder type of tts task. --voc_config VOC_CONFIG - Config of voc. Use deault config when it is None. + Config of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_stat VOC_STAT mean and standard deviation used to normalize spectrogram when training voc. @@ -207,9 +205,9 @@ optional arguments: output dir. ``` 1. `--am` is acoustic model type with the format {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model. +2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model. 3. `--voc` is vocoder type with the format {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. +4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 5. `--lang` is the model language, which can be `zh` or `en`. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 7. `--text` is the text file, which contains sentences to synthesize. 
diff --git a/examples/vctk/voc1/README.md b/examples/vctk/voc1/README.md index 1c3016f88..a0e06a420 100644 --- a/examples/vctk/voc1/README.md +++ b/examples/vctk/voc1/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a ## Dataset ### Download and Extract -Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/VCTK-Corpus-0.92`. +Download VCTK-0.92 from it's [Official Website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. @@ -70,7 +70,7 @@ Train a ParallelWaveGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG ParallelWaveGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA diff --git a/examples/vctk/voc5/README.md b/examples/vctk/voc5/README.md index 4eb25c02d..f2cbf27d2 100644 --- a/examples/vctk/voc5/README.md +++ b/examples/vctk/voc5/README.md @@ -3,7 +3,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010. ## Dataset ### Download and Extract -Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/VCTK-Corpus-0.92`. +Download VCTK-0.92 from it's [Official Website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`. ### Get MFA Result and Extract We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio. 
@@ -62,15 +62,13 @@ Here's the complete help message. ```text usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] - [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] - [--run-benchmark RUN_BENCHMARK] - [--profiler_options PROFILER_OPTIONS] + [--ngpu NGPU] -Train a ParallelWaveGAN model. +Train a HiFiGAN model. optional arguments: -h, --help show this help message and exit - --config CONFIG config file to overwrite default config. + --config CONFIG HiFiGAN config file. --train-metadata TRAIN_METADATA training data. --dev-metadata DEV_METADATA @@ -78,19 +76,6 @@ optional arguments: --output-dir OUTPUT_DIR output dir. --ngpu NGPU if ngpu == 0, use cpu. - -benchmark: - arguments related to benchmark. - - --batch-size BATCH_SIZE - batch size. - --max-iter MAX_ITER train max steps. - --run-benchmark RUN_BENCHMARK - runing benchmark or not, if True, use the --batch-size - and --max-iter. - --profiler_options PROFILER_OPTIONS - The option of profiler, which should be in format - "key1=value1;key2=value2;key3=value3". ``` 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. 
diff --git a/examples/voxceleb/sv0/README.md b/examples/voxceleb/sv0/README.md index 418102b4f..26c95aca9 100644 --- a/examples/voxceleb/sv0/README.md +++ b/examples/voxceleb/sv0/README.md @@ -141,11 +141,11 @@ using the `tar` scripts to unpack the model and then you can use the script to t For example: ``` -wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz -tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz +wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz +tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz source path.sh # If you have processed the data and get the manifest file, you can skip the following 2 steps -CUDA_VISIBLE_DEVICES= bash ./local/test.sh ./data sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_2/model/ conf/ecapa_tdnn.yaml +CUDA_VISIBLE_DEVICES= bash ./local/test.sh ./data sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1/model/ conf/ecapa_tdnn.yaml ``` The performance of the released models are shown in [this](./RESULTS.md) diff --git a/examples/voxceleb/sv0/RESULT.md b/examples/voxceleb/sv0/RESULT.md index 3a3f67d09..56ee887c6 100644 --- a/examples/voxceleb/sv0/RESULT.md +++ b/examples/voxceleb/sv0/RESULT.md @@ -4,4 +4,4 @@ | Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm | | --- | --- | --- | --- | --- | --- | --- | ---- | -| ECAPA-TDNN | 85M | 0.2.0 | conf/ecapa_tdnn.yaml |192 | test | 1.02 | 0.95 | +| ECAPA-TDNN | 85M | 0.2.1 | conf/ecapa_tdnn.yaml | 192 | test | 0.8188 | 0.7815| diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml index 3e3a13072..b7b71d77d 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -59,3 +59,11 @@ global_embedding_norm: True embedding_mean_norm: True embedding_std_norm: False +########################################### +# score-norm # +########################################### +score_norm: 
s-norm +cohort_size: 20000 # amount of imposter utterances in normalization cohort +n_train_snts: 400000 # used for normalization stats + + diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml index 5925e5730..40498c874 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml @@ -58,3 +58,10 @@ global_embedding_norm: True embedding_mean_norm: True embedding_std_norm: False +########################################### +# score-norm # +########################################### +score_norm: s-norm +cohort_size: 20000 # amount of imposter utterances in normalization cohort +n_train_snts: 400000 # used for normalization stats + diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py index b4486b6f0..e5a5dff7b 100644 --- a/examples/voxceleb/sv0/local/data_prepare.py +++ b/examples/voxceleb/sv0/local/data_prepare.py @@ -14,9 +14,9 @@ import argparse import paddle -from paddleaudio.datasets.voxceleb import VoxCeleb from yacs.config import CfgNode +from paddlespeech.audio.datasets.voxceleb import VoxCeleb from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.training.seeding import seed_everything diff --git a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py index 0d0163f15..7ad9bd6ec 100644 --- a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py +++ b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py @@ -21,9 +21,9 @@ import os from typing import List import tqdm -from paddleaudio import load as load_audio from yacs.config import CfgNode +from paddlespeech.audio import load as load_audio from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.utils.vector_utils import get_chunks diff --git 
a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py index ffd0d212d..40adf53de 100644 --- a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py +++ b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py @@ -22,9 +22,9 @@ import os import random import tqdm -from paddleaudio import load as load_audio from yacs.config import CfgNode +from paddlespeech.audio import load as load_audio from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.utils.vector_utils import get_chunks diff --git a/examples/wenetspeech/asr1/local/extract_meta.py b/examples/wenetspeech/asr1/local/extract_meta.py index 0e1b27278..954dbd780 100644 --- a/examples/wenetspeech/asr1/local/extract_meta.py +++ b/examples/wenetspeech/asr1/local/extract_meta.py @@ -1,18 +1,7 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. # Copyright 2021 Xiaomi Corporation (Author: Yongqing Wang) # Mobvoi Inc(Author: Di Wu, Binbin Zhang) +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/audio/paddleaudio/__init__.py b/paddlespeech/audio/__init__.py similarity index 100% rename from audio/paddleaudio/__init__.py rename to paddlespeech/audio/__init__.py diff --git a/audio/paddleaudio/backends/__init__.py b/paddlespeech/audio/backends/__init__.py similarity index 100% rename from audio/paddleaudio/backends/__init__.py rename to paddlespeech/audio/backends/__init__.py diff --git a/audio/paddleaudio/backends/soundfile_backend.py b/paddlespeech/audio/backends/soundfile_backend.py similarity index 100% rename from audio/paddleaudio/backends/soundfile_backend.py rename to paddlespeech/audio/backends/soundfile_backend.py diff --git a/audio/paddleaudio/backends/sox_backend.py b/paddlespeech/audio/backends/sox_backend.py similarity index 100% rename from audio/paddleaudio/backends/sox_backend.py rename to paddlespeech/audio/backends/sox_backend.py diff --git a/audio/paddleaudio/compliance/__init__.py b/paddlespeech/audio/compliance/__init__.py similarity index 100% rename from audio/paddleaudio/compliance/__init__.py rename to paddlespeech/audio/compliance/__init__.py diff --git a/audio/paddleaudio/compliance/kaldi.py b/paddlespeech/audio/compliance/kaldi.py similarity index 100% rename from audio/paddleaudio/compliance/kaldi.py rename to paddlespeech/audio/compliance/kaldi.py diff --git a/audio/paddleaudio/compliance/librosa.py b/paddlespeech/audio/compliance/librosa.py similarity index 100% rename from audio/paddleaudio/compliance/librosa.py rename to paddlespeech/audio/compliance/librosa.py diff --git a/audio/paddleaudio/datasets/__init__.py b/paddlespeech/audio/datasets/__init__.py similarity index 100% rename from audio/paddleaudio/datasets/__init__.py rename to paddlespeech/audio/datasets/__init__.py diff --git a/audio/paddleaudio/datasets/dataset.py b/paddlespeech/audio/datasets/dataset.py similarity index 100% rename from audio/paddleaudio/datasets/dataset.py rename to 
paddlespeech/audio/datasets/dataset.py diff --git a/audio/paddleaudio/datasets/esc50.py b/paddlespeech/audio/datasets/esc50.py similarity index 99% rename from audio/paddleaudio/datasets/esc50.py rename to paddlespeech/audio/datasets/esc50.py index e7477d40e..f5c7050f3 100644 --- a/audio/paddleaudio/datasets/esc50.py +++ b/paddlespeech/audio/datasets/esc50.py @@ -16,8 +16,8 @@ import os from typing import List from typing import Tuple +from ..utils import DATA_HOME from ..utils.download import download_and_decompress -from ..utils.env import DATA_HOME from .dataset import AudioClassificationDataset __all__ = ['ESC50'] diff --git a/audio/paddleaudio/datasets/gtzan.py b/paddlespeech/audio/datasets/gtzan.py similarity index 99% rename from audio/paddleaudio/datasets/gtzan.py rename to paddlespeech/audio/datasets/gtzan.py index cfea6f37e..1f6835a5a 100644 --- a/audio/paddleaudio/datasets/gtzan.py +++ b/paddlespeech/audio/datasets/gtzan.py @@ -17,8 +17,8 @@ import random from typing import List from typing import Tuple +from ..utils import DATA_HOME from ..utils.download import download_and_decompress -from ..utils.env import DATA_HOME from .dataset import AudioClassificationDataset __all__ = ['GTZAN'] diff --git a/audio/paddleaudio/datasets/hey_snips.py b/paddlespeech/audio/datasets/hey_snips.py similarity index 100% rename from audio/paddleaudio/datasets/hey_snips.py rename to paddlespeech/audio/datasets/hey_snips.py diff --git a/audio/paddleaudio/datasets/rirs_noises.py b/paddlespeech/audio/datasets/rirs_noises.py similarity index 100% rename from audio/paddleaudio/datasets/rirs_noises.py rename to paddlespeech/audio/datasets/rirs_noises.py diff --git a/audio/paddleaudio/datasets/tess.py b/paddlespeech/audio/datasets/tess.py similarity index 99% rename from audio/paddleaudio/datasets/tess.py rename to paddlespeech/audio/datasets/tess.py index 8faab9c39..1469fa5e2 100644 --- a/audio/paddleaudio/datasets/tess.py +++ b/paddlespeech/audio/datasets/tess.py @@ -17,8 +17,8 
@@ import random from typing import List from typing import Tuple +from ..utils import DATA_HOME from ..utils.download import download_and_decompress -from ..utils.env import DATA_HOME from .dataset import AudioClassificationDataset __all__ = ['TESS'] diff --git a/audio/paddleaudio/datasets/urban_sound.py b/paddlespeech/audio/datasets/urban_sound.py similarity index 99% rename from audio/paddleaudio/datasets/urban_sound.py rename to paddlespeech/audio/datasets/urban_sound.py index d97c4d1dc..0389cd5f9 100644 --- a/audio/paddleaudio/datasets/urban_sound.py +++ b/paddlespeech/audio/datasets/urban_sound.py @@ -16,8 +16,8 @@ import os from typing import List from typing import Tuple +from ..utils import DATA_HOME from ..utils.download import download_and_decompress -from ..utils.env import DATA_HOME from .dataset import AudioClassificationDataset __all__ = ['UrbanSound8K'] diff --git a/audio/paddleaudio/datasets/voxceleb.py b/paddlespeech/audio/datasets/voxceleb.py similarity index 100% rename from audio/paddleaudio/datasets/voxceleb.py rename to paddlespeech/audio/datasets/voxceleb.py diff --git a/audio/paddleaudio/features/__init__.py b/paddlespeech/audio/features/__init__.py similarity index 100% rename from audio/paddleaudio/features/__init__.py rename to paddlespeech/audio/features/__init__.py diff --git a/audio/paddleaudio/features/layers.py b/paddlespeech/audio/features/layers.py similarity index 100% rename from audio/paddleaudio/features/layers.py rename to paddlespeech/audio/features/layers.py diff --git a/audio/paddleaudio/functional/__init__.py b/paddlespeech/audio/functional/__init__.py similarity index 100% rename from audio/paddleaudio/functional/__init__.py rename to paddlespeech/audio/functional/__init__.py diff --git a/audio/paddleaudio/functional/functional.py b/paddlespeech/audio/functional/functional.py similarity index 100% rename from audio/paddleaudio/functional/functional.py rename to paddlespeech/audio/functional/functional.py diff --git 
a/audio/paddleaudio/functional/window.py b/paddlespeech/audio/functional/window.py similarity index 100% rename from audio/paddleaudio/functional/window.py rename to paddlespeech/audio/functional/window.py diff --git a/audio/paddleaudio/io/__init__.py b/paddlespeech/audio/io/__init__.py similarity index 100% rename from audio/paddleaudio/io/__init__.py rename to paddlespeech/audio/io/__init__.py diff --git a/audio/paddleaudio/metric/__init__.py b/paddlespeech/audio/metric/__init__.py similarity index 95% rename from audio/paddleaudio/metric/__init__.py rename to paddlespeech/audio/metric/__init__.py index d2b3a1360..7ce6f5cff 100644 --- a/audio/paddleaudio/metric/__init__.py +++ b/paddlespeech/audio/metric/__init__.py @@ -11,6 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .dtw import dtw_distance from .eer import compute_eer from .eer import compute_minDCF diff --git a/audio/paddleaudio/metric/eer.py b/paddlespeech/audio/metric/eer.py similarity index 100% rename from audio/paddleaudio/metric/eer.py rename to paddlespeech/audio/metric/eer.py diff --git a/audio/paddleaudio/sox_effects/__init__.py b/paddlespeech/audio/sox_effects/__init__.py similarity index 100% rename from audio/paddleaudio/sox_effects/__init__.py rename to paddlespeech/audio/sox_effects/__init__.py diff --git a/audio/paddleaudio/utils/__init__.py b/paddlespeech/audio/utils/__init__.py similarity index 88% rename from audio/paddleaudio/utils/__init__.py rename to paddlespeech/audio/utils/__init__.py index afb9cedd8..742f9f8ef 100644 --- a/audio/paddleaudio/utils/__init__.py +++ b/paddlespeech/audio/utils/__init__.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from ...cli.utils import DATA_HOME +from ...cli.utils import MODEL_HOME from .download import decompress from .download import download_and_decompress from .download import load_state_dict_from_url -from .env import DATA_HOME -from .env import MODEL_HOME -from .env import PPAUDIO_HOME -from .env import USER_HOME from .error import ParameterError from .log import Logger from .log import logger diff --git a/audio/paddleaudio/utils/download.py b/paddlespeech/audio/utils/download.py similarity index 100% rename from audio/paddleaudio/utils/download.py rename to paddlespeech/audio/utils/download.py diff --git a/audio/paddleaudio/utils/error.py b/paddlespeech/audio/utils/error.py similarity index 100% rename from audio/paddleaudio/utils/error.py rename to paddlespeech/audio/utils/error.py diff --git a/audio/paddleaudio/utils/log.py b/paddlespeech/audio/utils/log.py similarity index 100% rename from audio/paddleaudio/utils/log.py rename to paddlespeech/audio/utils/log.py diff --git a/audio/paddleaudio/utils/numeric.py b/paddlespeech/audio/utils/numeric.py similarity index 100% rename from audio/paddleaudio/utils/numeric.py rename to paddlespeech/audio/utils/numeric.py diff --git a/audio/paddleaudio/utils/time.py b/paddlespeech/audio/utils/time.py similarity index 100% rename from audio/paddleaudio/utils/time.py rename to paddlespeech/audio/utils/time.py diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index ddf0359bc..ca6993f2b 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -13,14 +13,7 @@ # limitations under the License. 
import _locale -from .asr import ASRExecutor from .base_commands import BaseCommand from .base_commands import HelpCommand -from .cls import CLSExecutor -from .st import STExecutor -from .stats import StatsExecutor -from .text import TextExecutor -from .tts import TTSExecutor -from .vector import VectorExecutor _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 863a933f2..00cad150e 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -29,30 +29,21 @@ from yacs.config import CfgNode from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import CLI_TIMER from ..utils import MODEL_HOME from ..utils import stats_wrapper from ..utils import timer_register -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] @timer_register -@cli_register( - name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__(task='asr', inference_type='offline') self.parser = argparse.ArgumentParser( prog='paddlespeech.asr', add_help=True) self.parser.add_argument( @@ -62,7 +53,8 @@ class ASRExecutor(BaseExecutor): type=str, default='conformer_wenetspeech', choices=[ - tag[:tag.index('-')] for tag in self.pretrained_models.keys() + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() ], help='Choose model type of asr task.') 
self.parser.add_argument( @@ -91,6 +83,12 @@ class ASRExecutor(BaseExecutor): 'attention_rescoring' ], help='only support transformer and conformer model') + self.parser.add_argument( + '--num_decoding_left_chunks', + '-num_left', + type=str, + default=-1, + help='only support transformer and conformer online model') self.parser.add_argument( '--ckpt_path', type=str, @@ -130,11 +128,14 @@ class ASRExecutor(BaseExecutor): sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. """ logger.info("start to init the model") + # default max_len: unit:second + self.max_len = 50 if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -142,14 +143,15 @@ class ASRExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str - res_path = self._get_pretrained_path(tag) # wenetspeech_zh - self.res_path = res_path + self.task_resource.set_task_model(tag, version=None) + self.res_path = self.task_resource.res_dir + self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( - res_path, - self.pretrained_models[tag]['ckpt_path'] + ".pdparams") - logger.info(res_path) + self.res_path, + self.task_resource.res_dict['ckpt_path'] + ".pdparams") + logger.info(self.res_path) else: self.cfg_path = os.path.abspath(cfg_path) @@ -164,35 +166,35 @@ class ASRExecutor(BaseExecutor): self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator - self.vocab = self.config.vocab_filepath + if self.config.spm_model_prefix: + 
self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + if "deepspeech2" in model_type: self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = self.pretrained_models[tag]['lm_url'] - lm_md5 = self.pretrained_models[tag]['lm_md5'] + + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - self.config.spm_model_prefix = os.path.join( - self.res_path, self.config.spm_model_prefix) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, - vocab=self.config.vocab_filepath, - spm_model_prefix=self.config.spm_model_prefix) + elif "conformer" in model_type or "transformer" in model_type: self.config.decode.decoding_method = decode_method + if num_decoding_left_chunks: + assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" + self.config.num_decoding_left_chunks = num_decoding_left_chunks else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, self.model_alias) + model_class = self.task_resource.get_model_class(model_name) model_conf = self.config model = model_class.from_config(model_conf) self.model = model @@ -203,7 +205,7 @@ class ASRExecutor(BaseExecutor): self.model.set_state_dict(model_dict) # compute the max len limit - if "conformer" in 
model_type or "transformer" in model_type or "wenetspeech" in model_type: + if "conformer" in model_type or "transformer" in model_type: # in transformer like model, we may use the subsample rate cnn network subsample_rate = self.model.subsampling_rate() frame_shift_ms = self.config.preprocess_config.process[0][ @@ -228,19 +230,7 @@ class ASRExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - audio, _ = self.collate_fn_test.process_utterance( - audio_file=audio_file, transcript=" ") - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - # vocab_list = collate_fn_test.vocab_list - self._inputs["audio"] = audio - self._inputs["audio_len"] = audio_len - logger.info(f"audio feat shape: {audio.shape}") - - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type: logger.info("get the preprocess conf") preprocess_conf = self.config.preprocess_config preprocess_args = {"train": False} @@ -248,7 +238,6 @@ class ASRExecutor(BaseExecutor): logger.info("read the audio file") audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) - if self.change_format: if audio.shape[1] >= 2: audio = audio.mean(axis=1, dtype=np.int16) @@ -291,7 +280,7 @@ class ASRExecutor(BaseExecutor): cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + if "deepspeech2" in model_type: decode_batch_size = audio.shape[0] self.model.decoder.init_decoder( decode_batch_size, self.text_feature.vocab_list, @@ -439,7 +428,7 @@ class ASRExecutor(BaseExecutor): if not parser_args.verbose: 
self.disable_task_loggers() - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) task_results = OrderedDict() has_exceptions = False @@ -472,6 +461,7 @@ class ASRExecutor(BaseExecutor): config: os.PathLike=None, ckpt_path: os.PathLike=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, force_yes: bool=False, rtf: bool=False, device=paddle.get_device()): @@ -479,11 +469,11 @@ class ASRExecutor(BaseExecutor): Python API to call an executor. """ audio_file = os.path.abspath(audio_file) - if not self._check(audio_file, sample_rate, force_yes): - sys.exit(-1) paddle.set_device(device) self._init_from_path(model, lang, sample_rate, config, decode_method, - ckpt_path) + num_decoding_left_chunks, ckpt_path) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) if rtf: k = self.__class__.__name__ CLI_TIMER[k]['start'].append(time.time()) diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py deleted file mode 100644 index 0f5218840..000000000 --- a/paddlespeech/cli/asr/pretrained_models.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". 
- # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "conformer_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '76cb19ed857e6623856b7cd7ebbfeda4', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/conformer/checkpoints/wenetspeech', - }, - "conformer_online_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz', - 'md5': - 'b8c02632b04da34aca88459835be54a6', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_10', - }, - "conformer_online_multicn-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', - 'md5': - '7989b3248c898070904cf042fd656003', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/multi_cn', - }, - "conformer_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz', - 'md5': - '3f073eccfa7bb14e0c6867d65fc0dc3a', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/conformer/checkpoints/avg_30', - }, - "conformer_online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz', - 'md5': - 'b374cfb93537761270b6224fb0bfc26a', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_30', - }, - "transformer_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '2c667da24922aad391eacafe37bc1660', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/transformer/checkpoints/avg_10', - }, - 
"deepspeech2online_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz', - 'md5': - 'e393d4d274af0f6967db24fc146e8074', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_10', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2offline_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', - 'md5': - '932c3593d62fe5c741b59b31318aa314', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz', - 'md5': - '98b87b171b7240b7cae6e07d8d0bc9be', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2offline_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - 'f5666c81ad015c8de03aac2bc92e5762', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', - 'lm_md5': - '099a601759d467cd0a8523ff939819c5' - }, -} - -model_alias = { - "deepspeech2offline": - "paddlespeech.s2t.models.ds2:DeepSpeech2Model", - "deepspeech2online": - "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", - "conformer": - 
"paddlespeech.s2t.models.u2:U2Model", - "conformer_online": - "paddlespeech.s2t.models.u2:U2Model", - "transformer": - "paddlespeech.s2t.models.u2:U2Model", - "wenetspeech": - "paddlespeech.s2t.models.u2:U2Model", -} diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 0a26b1203..f5e2246d8 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import argparse from typing import List +from prettytable import PrettyTable + +from ..resource import CommonTaskResource from .entry import commands from .utils import cli_register +from .utils import explicit_command_register from .utils import get_command -__all__ = [ - 'BaseCommand', - 'HelpCommand', -] +__all__ = ['BaseCommand', 'HelpCommand', 'StatsCommand'] @cli_register(name='paddlespeech') @@ -73,3 +75,73 @@ class VersionCommand: print(msg) return True + + +model_name_format = { + 'asr': 'Model-Language-Sample Rate', + 'cls': 'Model-Sample Rate', + 'st': 'Model-Source language-Target language', + 'text': 'Model-Task-Language', + 'tts': 'Model-Language', + 'vector': 'Model-Sample Rate' +} + + +@cli_register( + name='paddlespeech.stats', + description='Get speech tasks support models list.') +class StatsCommand: + def __init__(self): + self.parser = argparse.ArgumentParser( + prog='paddlespeech.stats', add_help=True) + self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] + self.parser.add_argument( + '--task', + type=str, + default='asr', + choices=self.task_choices, + help='Choose speech task.', + required=True) + + def show_support_models(self, pretrained_models: dict): + fields = model_name_format[self.task].split("-") + table = PrettyTable(fields) + for key in pretrained_models: + table.add_row(key.split("-")) + print(table) + + def 
execute(self, argv: List[str]) -> bool: + parser_args = self.parser.parse_args(argv) + self.task = parser_args.task + if self.task not in self.task_choices: + print("Please input correct speech task, choices = " + str( + self.task_choices)) + return + + pretrained_models = CommonTaskResource(task=self.task).pretrained_models + + try: + print( + "Here is the list of {} pretrained models released by PaddleSpeech that can be used by command line and python API" + .format(self.task.upper())) + self.show_support_models(pretrained_models) + except BaseException: + print("Failed to get the list of {} pretrained models.".format( + self.task.upper())) + + +# Dynamic import when running specific command +_commands = { + 'asr': ['Speech to text infer command.', 'ASRExecutor'], + 'cls': ['Audio classification infer command.', 'CLSExecutor'], + 'st': ['Speech translation infer command.', 'STExecutor'], + 'text': ['Text command.', 'TextExecutor'], + 'tts': ['Text to Speech infer command.', 'TTSExecutor'], + 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], +} + +for com, info in _commands.items(): + explicit_command_register( + name='paddlespeech.{}'.format(com), + description=info[0], + cls='paddlespeech.cli.{}.{}'.format(com, info[1])) diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 8b90f1244..942dc3b92 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -21,28 +21,19 @@ from typing import Union import numpy as np import paddle import yaml -from paddleaudio import load -from paddleaudio.features import LogMelSpectrogram from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.audio import load +from paddlespeech.audio.features import 
LogMelSpectrogram __all__ = ['CLSExecutor'] -@cli_register( - name='paddlespeech.cls', description='Audio classification infer command.') class CLSExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__(task='cls') self.parser = argparse.ArgumentParser( prog='paddlespeech.cls', add_help=True) self.parser.add_argument( @@ -52,7 +43,8 @@ class CLSExecutor(BaseExecutor): type=str, default='panns_cnn14', choices=[ - tag[:tag.index('-')] for tag in self.pretrained_models.keys() + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() ], help='Choose model type of cls task.') self.parser.add_argument( @@ -105,13 +97,16 @@ class CLSExecutor(BaseExecutor): if label_file is None or ckpt_path is None: tag = model_type + '-' + '32k' # panns_cnn14-32k - self.res_path = self._get_pretrained_path(tag) + self.task_resource.set_task_model(tag, version=None) self.cfg_path = os.path.join( - self.res_path, self.pretrained_models[tag]['cfg_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) self.label_file = os.path.join( - self.res_path, self.pretrained_models[tag]['label_file']) + self.task_resource.res_dir, + self.task_resource.res_dict['label_file']) self.ckpt_path = os.path.join( - self.res_path, self.pretrained_models[tag]['ckpt_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path']) else: self.cfg_path = os.path.abspath(cfg_path) self.label_file = os.path.abspath(label_file) @@ -128,7 +123,7 @@ class CLSExecutor(BaseExecutor): self._label_list.append(line.strip()) # model - model_class = dynamic_import(model_type, self.model_alias) + model_class = self.task_resource.get_model_class(model_type) model_dict = paddle.load(self.ckpt_path) self.model = model_class(extract_embedding=False) self.model.set_state_dict(model_dict) @@ -205,7 +200,7 @@ class CLSExecutor(BaseExecutor): if not 
parser_args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) task_results = OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/cls/pretrained_models.py b/paddlespeech/cli/cls/pretrained_models.py deleted file mode 100644 index 1d66850aa..000000000 --- a/paddlespeech/cli/cls/pretrained_models.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". 
- # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "panns_cnn6-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', - 'md5': '4cf09194a95df024fd12f84712cf0f9c', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn6.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn10-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', - 'md5': 'cb8427b22176cc2116367d14847f5413', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn10.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn14-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', - 'md5': 'e3b9b5614a1595001161d0ab95edee97', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn14.pdparams', - 'label_file': 'audioset_labels.txt', - }, -} - -model_alias = { - "panns_cnn6": "paddlespeech.cls.models.panns:CNN6", - "panns_cnn10": "paddlespeech.cls.models.panns:CNN10", - "panns_cnn14": "paddlespeech.cls.models.panns:CNN14", -} diff --git a/paddlespeech/cli/download.py b/paddlespeech/cli/download.py index 0f09b6fad..ec7258747 100644 --- a/paddlespeech/cli/download.py +++ b/paddlespeech/cli/download.py @@ -86,7 +86,7 @@ def get_path_from_url(url, str: a local path to save downloaded models & weights & datasets. """ - from paddle.fluid.dygraph.parallel import ParallelEnv + from paddle.distributed import ParallelEnv assert _is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir diff --git a/paddlespeech/cli/entry.py b/paddlespeech/cli/entry.py index 32123ece7..e0c306d62 100644 --- a/paddlespeech/cli/entry.py +++ b/paddlespeech/cli/entry.py @@ -34,6 +34,11 @@ def _execute(): # The method 'execute' of a command instance returns 'True' for a success # while 'False' for a failure. 
Here converts this result into a exit status # in bash: 0 for a success and 1 for a failure. + if not callable(com['_entry']): + i = com['_entry'].rindex('.') + module, cls = com['_entry'][:i], com['_entry'][i + 1:] + exec("from {} import {}".format(module, cls)) + com['_entry'] = locals()[cls] status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 return status diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 4a631c7f5..d390f947d 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -24,9 +24,8 @@ from typing import Union import paddle +from ..resource import CommonTaskResource from .log import logger -from .utils import download_and_decompress -from .utils import MODEL_HOME class BaseExecutor(ABC): @@ -34,11 +33,10 @@ class BaseExecutor(ABC): An abstract executor of paddlespeech tasks. """ - def __init__(self): + def __init__(self, task: str, **kwargs): self._inputs = OrderedDict() self._outputs = OrderedDict() - self.pretrained_models = OrderedDict() - self.model_alias = OrderedDict() + self.task_resource = CommonTaskResource(task=task, **kwargs) @abstractmethod def _init_from_path(self, *args, **kwargs): @@ -98,8 +96,8 @@ class BaseExecutor(ABC): """ pass - def get_task_source(self, input_: Union[str, os.PathLike, None] - ) -> Dict[str, Union[str, os.PathLike]]: + def get_input_source(self, input_: Union[str, os.PathLike, None] + ) -> Dict[str, Union[str, os.PathLike]]: """ Get task input source from command line input. @@ -115,15 +113,17 @@ class BaseExecutor(ABC): ret = OrderedDict() if input_ is None: # Take input from stdin - for i, line in enumerate(sys.stdin): - line = line.strip() - if len(line.split(' ')) == 1: - ret[str(i + 1)] = line - elif len(line.split(' ')) == 2: - id_, info = line.split(' ') - ret[id_] = info - else: # No valid input info from one line. - continue + if not sys.stdin.isatty( + ): # Avoid getting stuck when stdin is empty. 
+ for i, line in enumerate(sys.stdin): + line = line.strip() + if len(line.split(' ')) == 1: + ret[str(i + 1)] = line + elif len(line.split(' ')) == 2: + id_, info = line.split(' ') + ret[id_] = info + else: # No valid input info from one line. + continue else: ret[1] = input_ return ret @@ -219,23 +219,6 @@ class BaseExecutor(ABC): for l in loggers: l.disabled = True - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(self.pretrained_models.keys()) - assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(self.pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def show_rtf(self, info: Dict[str, List[float]]): """ Calculate rft of current task and show results. 
diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index 29d95f799..e1ce181af 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -28,27 +28,25 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper -from .pretrained_models import kaldi_bins -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["STExecutor"] +kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} + -@cli_register( - name="paddlespeech.st", description="Speech translation infer command.") class STExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models + super().__init__(task='st') self.kaldi_bins = kaldi_bins self.parser = argparse.ArgumentParser( @@ -60,7 +58,8 @@ class STExecutor(BaseExecutor): type=str, default="fat_st_ted", choices=[ - tag[:tag.index('-')] for tag in self.pretrained_models.keys() + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() ], help="Choose model type of st task.") self.parser.add_argument( @@ -134,14 +133,16 @@ class STExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: tag = model_type + "-" + src_lang + "-" + tgt_lang - res_path = self._get_pretrained_path(tag) - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]["cfg_path"]) - self.ckpt_path = os.path.join(res_path, - pretrained_models[tag]["ckpt_path"]) - logger.info(res_path) + 
self.task_resource.set_task_model(tag, version=None) + self.cfg_path = os.path.join( + self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path']) logger.info(self.cfg_path) logger.info(self.ckpt_path) + res_path = self.task_resource.res_dir else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path) @@ -166,7 +167,7 @@ class STExecutor(BaseExecutor): model_conf = self.config model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, self.model_alias) + model_class = self.task_resource.get_model_class(model_name) self.model = model_class.from_config(model_conf) self.model.eval() @@ -304,7 +305,7 @@ class STExecutor(BaseExecutor): if not parser_args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) task_results = OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/st/pretrained_models.py b/paddlespeech/cli/st/pretrained_models.py deleted file mode 100644 index cc7410d25..000000000 --- a/paddlespeech/cli/st/pretrained_models.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -pretrained_models = { - "fat_st_ted-en-zh": { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", - "md5": - "d62063f35a16d91210a71081bd2dd557", - "cfg_path": - "model.yaml", - "ckpt_path": - "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", - } -} - -model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} - -kaldi_bins = { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", - "md5": - "c0682303b3f3393dbf6ed4c4e35a53eb", -} diff --git a/paddlespeech/cli/stats/infer.py b/paddlespeech/cli/stats/infer.py deleted file mode 100644 index 7cf4f2368..000000000 --- a/paddlespeech/cli/stats/infer.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -from typing import List - -from prettytable import PrettyTable - -from ..utils import cli_register -from ..utils import stats_wrapper - -__all__ = ['StatsExecutor'] - -model_name_format = { - 'asr': 'Model-Language-Sample Rate', - 'cls': 'Model-Sample Rate', - 'st': 'Model-Source language-Target language', - 'text': 'Model-Task-Language', - 'tts': 'Model-Language', - 'vector': 'Model-Sample Rate' -} - - -@cli_register( - name='paddlespeech.stats', - description='Get speech tasks support models list.') -class StatsExecutor(): - def __init__(self): - super().__init__() - - self.parser = argparse.ArgumentParser( - prog='paddlespeech.stats', add_help=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] - self.parser.add_argument( - '--task', - type=str, - default='asr', - choices=self.task_choices, - help='Choose speech task.', - required=True) - - def show_support_models(self, pretrained_models: dict): - fields = model_name_format[self.task].split("-") - table = PrettyTable(fields) - for key in pretrained_models: - table.add_row(key.split("-")) - print(table) - - def execute(self, argv: List[str]) -> bool: - """ - Command line entry. - """ - parser_args = self.parser.parse_args(argv) - has_exceptions = False - try: - self(parser_args.task) - except Exception as e: - has_exceptions = True - if has_exceptions: - return False - else: - return True - - @stats_wrapper - def __call__( - self, - task: str=None, ): - """ - Python API to call an executor. 
- """ - self.task = task - if self.task not in self.task_choices: - print("Please input correct speech task, choices = " + str( - self.task_choices)) - - elif self.task == 'asr': - try: - from ..asr.pretrained_models import pretrained_models - print( - "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of ASR pretrained models.") - - elif self.task == 'cls': - try: - from ..cls.pretrained_models import pretrained_models - print( - "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of CLS pretrained models.") - - elif self.task == 'st': - try: - from ..st.pretrained_models import pretrained_models - print( - "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of ST pretrained models.") - - elif self.task == 'text': - try: - from ..text.pretrained_models import pretrained_models - print( - "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of TEXT pretrained models.") - - elif self.task == 'tts': - try: - from ..tts.pretrained_models import pretrained_models - print( - "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of TTS pretrained models.") - - elif self.task == 'vector': - try: - from ..vector.pretrained_models import 
pretrained_models - print( - "Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print( - "Failed to get the list of Speaker Recognition pretrained models." - ) diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index 69e62e4b4..7b8faf99c 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -21,26 +21,16 @@ from typing import Union import paddle -from ...s2t.utils.dynamic_import import dynamic_import from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from .pretrained_models import tokenizer_alias __all__ = ['TextExecutor'] -@cli_register(name='paddlespeech.text', description='Text infer command.') class TextExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - self.tokenizer_alias = tokenizer_alias - + super().__init__(task='text') self.parser = argparse.ArgumentParser( prog='paddlespeech.text', add_help=True) self.parser.add_argument( @@ -56,7 +46,8 @@ class TextExecutor(BaseExecutor): type=str, default='ernie_linear_p7_wudao', choices=[ - tag[:tag.index('-')] for tag in self.pretrained_models.keys() + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() ], help='Choose model type of text task.') self.parser.add_argument( @@ -114,13 +105,16 @@ class TextExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None or vocab_file is None: tag = '-'.join([model_type, task, lang]) - self.res_path = self._get_pretrained_path(tag) + self.task_resource.set_task_model(tag, version=None) self.cfg_path = os.path.join( - self.res_path, self.pretrained_models[tag]['cfg_path']) + 
self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( - self.res_path, self.pretrained_models[tag]['ckpt_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path']) self.vocab_file = os.path.join( - self.res_path, self.pretrained_models[tag]['vocab_file']) + self.task_resource.res_dir, + self.task_resource.res_dict['vocab_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path) @@ -135,8 +129,8 @@ class TextExecutor(BaseExecutor): self._punc_list.append(line.strip()) # model - model_class = dynamic_import(model_name, self.model_alias) - tokenizer_class = dynamic_import(model_name, self.tokenizer_alias) + model_class, tokenizer_class = self.task_resource.get_model_class( + model_name) self.model = model_class( cfg_path=self.cfg_path, ckpt_path=self.ckpt_path) self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0') @@ -226,7 +220,7 @@ class TextExecutor(BaseExecutor): if not parser_args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) task_results = OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/text/pretrained_models.py b/paddlespeech/cli/text/pretrained_models.py deleted file mode 100644 index 817d3caa3..000000000 --- a/paddlespeech/cli/text/pretrained_models.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "ernie_linear_p7_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', - 'md5': - '12283e2ddde1797c5d1e57036b512746', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, - "ernie_linear_p3_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', - 'md5': - '448eb2fdf85b6a997e7e652e80c51dd2', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, -} - -model_alias = { - "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", - "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", -} - -tokenizer_alias = { - "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", - "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", -} diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 1c7199306..4e0337bcc 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -28,11 +28,7 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.frontend import English from 
paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore @@ -40,14 +36,9 @@ from paddlespeech.t2s.modules.normalizer import ZScore __all__ = ['TTSExecutor'] -@cli_register( - name='paddlespeech.tts', description='Text to Speech infer command.') class TTSExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__('tts') self.parser = argparse.ArgumentParser( prog='paddlespeech.tts', add_help=True) self.parser.add_argument( @@ -186,19 +177,23 @@ class TTSExecutor(BaseExecutor): return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_config = os.path.join( - am_res_path, self.pretrained_models[am_tag]['config']) - self.am_ckpt = os.path.join(am_res_path, - self.pretrained_models[am_tag]['ckpt']) + self.am_res_path = self.task_resource.res_dir + self.am_config = os.path.join(self.am_res_path, + self.task_resource.res_dict['config']) + self.am_ckpt = os.path.join(self.am_res_path, + self.task_resource.res_dict['ckpt']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, self.task_resource.res_dict['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) - logger.info(am_res_path) + self.am_res_path, self.task_resource.res_dict['phones_dict']) + logger.info(self.am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) else: @@ -210,32 +205,37 @@ class TTSExecutor(BaseExecutor): # for speedyspeech self.tones_dict = None - if 'tones_dict' in self.pretrained_models[am_tag]: + if 'tones_dict' in 
self.task_resource.res_dict: self.tones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['tones_dict']) + self.am_res_path, self.task_resource.res_dict['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in self.pretrained_models[am_tag]: + if 'speaker_dict' in self.task_resource.res_dict: self.speaker_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speaker_dict']) + self.am_res_path, self.task_resource.res_dict['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict # voc voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None or voc_config is None or voc_stat is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) self.voc_stat = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) - logger.info(voc_res_path) + self.voc_res_path, + self.task_resource.voc_res_dict['speech_stats']) + logger.info(self.voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) else: @@ -285,9 +285,9 @@ class TTSExecutor(BaseExecutor): # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] - am_class = dynamic_import(am_name, self.model_alias) - am_inference_class = dynamic_import(am_name + '_inference', - self.model_alias) + am_class = self.task_resource.get_model_class(am_name) + am_inference_class = self.task_resource.get_model_class(am_name + + '_inference') if am_name == 'fastspeech2': am = am_class( @@ 
-316,9 +316,9 @@ class TTSExecutor(BaseExecutor): # vocoder # model: {model_name}_{dataset} voc_name = voc[:voc.rindex('_')] - voc_class = dynamic_import(voc_name, self.model_alias) - voc_inference_class = dynamic_import(voc_name + '_inference', - self.model_alias) + voc_class = self.task_resource.get_model_class(voc_name) + voc_inference_class = self.task_resource.get_model_class(voc_name + + '_inference') if voc_name != 'wavernn': voc = voc_class(**self.voc_config["generator_params"]) voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) @@ -446,7 +446,7 @@ class TTSExecutor(BaseExecutor): if not args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(args.input) + task_source = self.get_input_source(args.input) task_results = OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/tts/pretrained_models.py b/paddlespeech/cli/tts/pretrained_models.py deleted file mode 100644 index 65254a935..000000000 --- a/paddlespeech/cli/tts/pretrained_models.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', - 'md5': - '6f6fa967b408454b6662c8c00c0027cb', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'feats_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - }, - - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', - 'md5': - 'ffed800c93deaf16ca9b3af89bfcd747', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_100000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', - 'md5': - 'f4dd4a5f49a4552b77981f544ab3392e', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_96400.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - "fastspeech2_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', - 'md5': - '743e5024ca1e17a88c5c271db9779ba4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_66200.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - # tacotron2 - "tacotron2_csmsc-zh": { - 'url': - 
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', - 'md5': - '0df4b6f0bcbe0d73c5ed6df8867ab91a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "tacotron2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', - 'md5': - '6a5eddd81ae0e81d16959b97481135f3', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_60300.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', - 'md5': - '2e481633325b5bdf0a3823c714d2c117', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', - 'md5': - '53610ba9708fd3008ccaf8e99dacbaf0', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', - 'md5': - 'd7598fa41ad362d62f85ffc0f07e3d84', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "pwgan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', - 'md5': - 'b3da1defcde3e578be71eb284cb89f2c', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 
'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # style_melgan - "style_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - '5de2d5348f396de0c966926b8c462755', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', - 'md5': - '70e9131695decbca06a65fe51ed38a72', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', - 'md5': - '3bb49bc75032ed12f79c00c8cc79a09a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', - 'md5': - '7da8f88359bca2457e705d924cf27bd4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # wavernn - "wavernn_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', - 'md5': - 'ee37b752f09bcba8f2af3b777ca38e13', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_400000.pdz', - 'speech_stats': - 'feats_stats.npy', - } 
-} - -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index 82d40c8bc..21c887e99 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -28,7 +28,7 @@ import requests import yaml from paddle.framework import load -import paddleaudio +import paddlespeech.audio from . 
import download from .entry import commands try: @@ -41,6 +41,7 @@ requests.adapters.DEFAULT_RETRIES = 3 __all__ = [ 'timer_register', 'cli_register', + 'explicit_command_register', 'get_command', 'download_and_decompress', 'load_state_dict_from_url', @@ -70,6 +71,16 @@ def cli_register(name: str, description: str='') -> Any: return _warpper +def explicit_command_register(name: str, description: str='', cls: str=''): + items = name.split('.') + com = commands + for item in items: + com = com[item] + com['_entry'] = cls + if description: + com['_description'] = description + + def get_command(name: str) -> Any: items = name.split('.') com = commands @@ -179,6 +190,7 @@ def _get_sub_home(directory): PPSPEECH_HOME = _get_paddlespcceh_home() MODEL_HOME = _get_sub_home('models') CONF_HOME = _get_sub_home('conf') +DATA_HOME = _get_sub_home('datasets') def _md5(text: str): @@ -270,7 +282,7 @@ def _note_one_stat(cls_name, params={}): if 'audio_file' in params: try: - _, sr = paddleaudio.load(params['audio_file']) + _, sr = paddlespeech.audio.load(params['audio_file']) except Exception: sr = -1 diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 0a169f8bb..4bc8e135a 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -22,30 +22,20 @@ from typing import Union import paddle import soundfile -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger -from ..utils import cli_register from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.audio.backends import load as load_audio +from paddlespeech.audio.compliance.librosa import melspectrogram from paddlespeech.vector.io.batch import feature_normalize from 
paddlespeech.vector.modules.sid_model import SpeakerIdetification -@cli_register( - name="paddlespeech.vector", - description="Speech to vector embedding infer command.") class VectorExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__('vector') self.parser = argparse.ArgumentParser( prog="paddlespeech.vector", add_help=True) @@ -53,7 +43,10 @@ class VectorExecutor(BaseExecutor): "--model", type=str, default="ecapatdnn_voxceleb12", - choices=["ecapatdnn_voxceleb12"], + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], help="Choose model type of vector task.") self.parser.add_argument( "--task", @@ -123,7 +116,7 @@ class VectorExecutor(BaseExecutor): self.disable_task_loggers() # stage 2: read the input data and store them as a list - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) logger.info(f"task source: {task_source}") # stage 3: process the audio one by one @@ -300,17 +293,18 @@ class VectorExecutor(BaseExecutor): # get the mode from pretrained list sample_rate_str = "16k" if sample_rate == 16000 else "8k" tag = model_type + "-" + sample_rate_str + self.task_resource.set_task_model(tag, version=None) logger.info(f"load the pretrained model: {tag}") # get the model from the pretrained list # we download the pretrained model and store it in the res_path - res_path = self._get_pretrained_path(tag) - self.res_path = res_path + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( - res_path, - self.pretrained_models[tag]['ckpt_path'] + '.pdparams') + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path'] + '.pdparams') else: # get the model from disk 
self.cfg_path = os.path.abspath(cfg_path) @@ -329,8 +323,8 @@ class VectorExecutor(BaseExecutor): # stage 3: get the model name to instance the model network with dynamic_import logger.info("start to dynamic import the model class") model_name = model_type[:model_type.rindex('_')] + model_class = self.task_resource.get_model_class(model_name) logger.info(f"model name {model_name}") - model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config.model backbone = model_class(**model_conf) model = SpeakerIdetification( diff --git a/paddlespeech/cli/vector/pretrained_models.py b/paddlespeech/cli/vector/pretrained_models.py deleted file mode 100644 index 686a22d8f..000000000 --- a/paddlespeech/cli/vector/pretrained_models.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". - # e.g. "ecapatdnn_voxceleb12-16k". 
- # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: - # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" - "ecapatdnn_voxceleb12-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', - 'md5': - 'cc33023c54ab346cd318408f43fcaf95', - 'cfg_path': - 'conf/model.yaml', # the yaml config path - 'ckpt_path': - 'model/model', # the format is ${dir}/{model_name}, - # so the first 'model' is dir, the second 'model' is the name - # this means we have a model stored as model/model.pdparams - }, -} - -model_alias = { - "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", -} diff --git a/paddlespeech/cls/exps/panns/deploy/predict.py b/paddlespeech/cls/exps/panns/deploy/predict.py index ee566ed4f..fe1c93fa8 100644 --- a/paddlespeech/cls/exps/panns/deploy/predict.py +++ b/paddlespeech/cls/exps/panns/deploy/predict.py @@ -16,11 +16,12 @@ import os import numpy as np from paddle import inference -from paddleaudio.backends import load as load_audio -from paddleaudio.datasets import ESC50 -from paddleaudio.features import melspectrogram from scipy.special import softmax +from paddlespeech.audio.backends import load as load_audio +from paddlespeech.audio.datasets import ESC50 +from paddlespeech.audio.features import melspectrogram + # yapf: disable parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.") diff --git a/paddlespeech/cls/exps/panns/export_model.py b/paddlespeech/cls/exps/panns/export_model.py index 63b22981a..e62d58f02 100644 --- a/paddlespeech/cls/exps/panns/export_model.py +++ b/paddlespeech/cls/exps/panns/export_model.py @@ -15,8 +15,8 @@ import argparse import os import paddle -from paddleaudio.datasets import ESC50 +from paddlespeech.audio.datasets import ESC50 from paddlespeech.cls.models import cnn14 from 
paddlespeech.cls.models import SoundClassifier diff --git a/paddlespeech/cls/exps/panns/predict.py b/paddlespeech/cls/exps/panns/predict.py index a3f9f9a9b..97759a89d 100644 --- a/paddlespeech/cls/exps/panns/predict.py +++ b/paddlespeech/cls/exps/panns/predict.py @@ -17,12 +17,12 @@ import os import paddle import paddle.nn.functional as F import yaml -from paddleaudio.backends import load as load_audio -from paddleaudio.features import LogMelSpectrogram -from paddleaudio.utils import logger +from paddlespeech.audio.backends import load as load_audio +from paddlespeech.audio.features import LogMelSpectrogram +from paddlespeech.audio.utils import logger from paddlespeech.cls.models import SoundClassifier -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.utils.dynamic_import import dynamic_import # yapf: disable parser = argparse.ArgumentParser(__doc__) diff --git a/paddlespeech/cls/exps/panns/train.py b/paddlespeech/cls/exps/panns/train.py index 5a2f3042a..fba38a01c 100644 --- a/paddlespeech/cls/exps/panns/train.py +++ b/paddlespeech/cls/exps/panns/train.py @@ -16,12 +16,12 @@ import os import paddle import yaml -from paddleaudio.features import LogMelSpectrogram -from paddleaudio.utils import logger -from paddleaudio.utils import Timer +from paddlespeech.audio.features import LogMelSpectrogram +from paddlespeech.audio.utils import logger +from paddlespeech.audio.utils import Timer from paddlespeech.cls.models import SoundClassifier -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.utils.dynamic_import import dynamic_import # yapf: disable parser = argparse.ArgumentParser(__doc__) diff --git a/paddlespeech/cls/models/panns/panns.py b/paddlespeech/cls/models/panns/panns.py index b442b2fd1..4befe7aa4 100644 --- a/paddlespeech/cls/models/panns/panns.py +++ b/paddlespeech/cls/models/panns/panns.py @@ -15,8 +15,9 @@ import os import paddle.nn as nn import paddle.nn.functional as F -from 
paddleaudio.utils.download import load_state_dict_from_url -from paddleaudio.utils.env import MODEL_HOME + +from paddlespeech.audio.utils import MODEL_HOME +from paddlespeech.audio.utils.download import load_state_dict_from_url __all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6'] diff --git a/paddlespeech/kws/exps/mdtc/compute_det.py b/paddlespeech/kws/exps/mdtc/compute_det.py index e43a953db..853056966 100644 --- a/paddlespeech/kws/exps/mdtc/compute_det.py +++ b/paddlespeech/kws/exps/mdtc/compute_det.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlespeech/kws/exps/mdtc/plot_det_curve.py b/paddlespeech/kws/exps/mdtc/plot_det_curve.py index a3ea21eff..4960281ee 100644 --- a/paddlespeech/kws/exps/mdtc/plot_det_curve.py +++ b/paddlespeech/kws/exps/mdtc/plot_det_curve.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# Menglong Xu # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlespeech/kws/exps/mdtc/score.py b/paddlespeech/kws/exps/mdtc/score.py index 1b5e1e296..556455ca1 100644 --- a/paddlespeech/kws/exps/mdtc/score.py +++ b/paddlespeech/kws/exps/mdtc/score.py @@ -1,4 +1,6 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) +# 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py index 5a9ca92d1..94e45d590 100644 --- a/paddlespeech/kws/exps/mdtc/train.py +++ b/paddlespeech/kws/exps/mdtc/train.py @@ -14,10 +14,10 @@ import os import paddle -from paddleaudio.utils import logger -from paddleaudio.utils import Timer from yacs.config import CfgNode +from paddlespeech.audio.utils import logger +from paddlespeech.audio.utils import Timer from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.models.loss import max_pooling_loss from paddlespeech.kws.models.mdtc import KWSModel diff --git a/paddlespeech/kws/models/loss.py b/paddlespeech/kws/models/loss.py index 64c9a32c9..bda77f2ba 100644 --- a/paddlespeech/kws/models/loss.py +++ b/paddlespeech/kws/models/loss.py @@ -1,3 +1,4 @@ +# Copyright (c) 2021 Binbin Zhang # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlespeech/kws/models/mdtc.py b/paddlespeech/kws/models/mdtc.py index 5d2e5de64..c605a02b6 100644 --- a/paddlespeech/kws/models/mdtc.py +++ b/paddlespeech/kws/models/mdtc.py @@ -1,3 +1,4 @@ +# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com) # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/examples/other/1xt2x/src_deepspeech2x/models/ds2/__init__.py b/paddlespeech/resource/__init__.py similarity index 72% rename from examples/other/1xt2x/src_deepspeech2x/models/ds2/__init__.py rename to paddlespeech/resource/__init__.py index 39bea5bf9..e143413af 100644 --- a/examples/other/1xt2x/src_deepspeech2x/models/ds2/__init__.py +++ b/paddlespeech/resource/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,7 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .deepspeech2 import DeepSpeech2InferModel -from .deepspeech2 import DeepSpeech2Model - -__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] +from .resource import CommonTaskResource diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py new file mode 100644 index 000000000..5309fd86f --- /dev/null +++ b/paddlespeech/resource/model_alias.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = [ + 'model_alias', +] + +# Records of model name to import class +model_alias = { + # --------------------------------- + # -------------- ASR -------------- + # --------------------------------- + "deepspeech2offline": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], + "deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], + "conformer": ["paddlespeech.s2t.models.u2:U2Model"], + "conformer_online": ["paddlespeech.s2t.models.u2:U2Model"], + "transformer": ["paddlespeech.s2t.models.u2:U2Model"], + "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"], + + # --------------------------------- + # -------------- CLS -------------- + # --------------------------------- + "panns_cnn6": ["paddlespeech.cls.models.panns:CNN6"], + "panns_cnn10": ["paddlespeech.cls.models.panns:CNN10"], + "panns_cnn14": ["paddlespeech.cls.models.panns:CNN14"], + + # --------------------------------- + # -------------- ST --------------- + # --------------------------------- + "fat_st": ["paddlespeech.s2t.models.u2_st:U2STModel"], + + # --------------------------------- + # -------------- TEXT ------------- + # --------------------------------- + "ernie_linear_p7": [ + "paddlespeech.text.models:ErnieLinear", + "paddlenlp.transformers:ErnieTokenizer" + ], + "ernie_linear_p3": [ + "paddlespeech.text.models:ErnieLinear", + "paddlenlp.transformers:ErnieTokenizer" + ], + + # --------------------------------- + # -------------- TTS -------------- + # --------------------------------- + # acoustic model + "speedyspeech": ["paddlespeech.t2s.models.speedyspeech:SpeedySpeech"], + "speedyspeech_inference": + ["paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference"], + "fastspeech2": ["paddlespeech.t2s.models.fastspeech2:FastSpeech2"], + "fastspeech2_inference": + ["paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference"], + "tacotron2": ["paddlespeech.t2s.models.tacotron2:Tacotron2"], + "tacotron2_inference": + ["paddlespeech.t2s.models.tacotron2:Tacotron2Inference"], 
+ # voc + "pwgan": ["paddlespeech.t2s.models.parallel_wavegan:PWGGenerator"], + "pwgan_inference": + ["paddlespeech.t2s.models.parallel_wavegan:PWGInference"], + "mb_melgan": ["paddlespeech.t2s.models.melgan:MelGANGenerator"], + "mb_melgan_inference": ["paddlespeech.t2s.models.melgan:MelGANInference"], + "style_melgan": ["paddlespeech.t2s.models.melgan:StyleMelGANGenerator"], + "style_melgan_inference": + ["paddlespeech.t2s.models.melgan:StyleMelGANInference"], + "hifigan": ["paddlespeech.t2s.models.hifigan:HiFiGANGenerator"], + "hifigan_inference": ["paddlespeech.t2s.models.hifigan:HiFiGANInference"], + "wavernn": ["paddlespeech.t2s.models.wavernn:WaveRNN"], + "wavernn_inference": ["paddlespeech.t2s.models.wavernn:WaveRNNInference"], + + # --------------------------------- + # ------------ Vector ------------- + # --------------------------------- + "ecapatdnn": ["paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn"], +} diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py new file mode 100644 index 000000000..f79961d64 --- /dev/null +++ b/paddlespeech/resource/pretrained_models.py @@ -0,0 +1,838 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = [ + 'asr_dynamic_pretrained_models', + 'asr_static_pretrained_models', + 'cls_dynamic_pretrained_models', + 'cls_static_pretrained_models', + 'st_dynamic_pretrained_models', + 'st_kaldi_bins', + 'text_dynamic_pretrained_models', + 'tts_dynamic_pretrained_models', + 'tts_static_pretrained_models', + 'tts_onnx_pretrained_models', + 'vector_dynamic_pretrained_models', +] + +# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". +# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". +# Command line and python api use "{model_name}[_{dataset}]" as --model, usage: +# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + +# --------------------------------- +# -------------- ASR -------------- +# --------------------------------- +asr_dynamic_pretrained_models = { + "conformer_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '76cb19ed857e6623856b7cd7ebbfeda4', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/wenetspeech', + }, + }, + "conformer_online_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz', + 'md5': + 'b8c02632b04da34aca88459835be54a6', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_10', + 'model': + 'exp/chunk_conformer/checkpoints/avg_10.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/avg_10.pdparams', + 'lm_url': + '', + 'lm_md5': + '', + }, + }, + "conformer_online_multicn-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', + 'md5': + '7989b3248c898070904cf042fd656003', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + }, + '2.0': { + 'url': + 
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', + 'md5': + '0ac93d390552336f2a906aec9e33c5fa', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + 'model': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3', + }, + }, + "conformer_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '3f073eccfa7bb14e0c6867d65fc0dc3a', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/avg_30', + }, + }, + "conformer_online_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz', + 'md5': + 'b374cfb93537761270b6224fb0bfc26a', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_30', + }, + }, + "transformer_librispeech-en-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '2c667da24922aad391eacafe37bc1660', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, + }, + "deepspeech2online_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz', + 'md5': + 'b0c77e7f8881e0a27b82127d1abb8d5f', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_10', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2offline_aishell-zh-16k": { + '1.0': { + 'url': + 
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz', + 'md5': + '4d26066c6f19f52087425dc722ae5b13', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_10', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2online_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz', + 'md5': + 'df5ddeac8b679a470176649ac4b78726', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'model': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2offline_librispeech-en-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz', + 'md5': + 'ed9e2b008a65268b3484020281ab048c', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_5', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', + 'lm_md5': + '099a601759d467cd0a8523ff939819c5' + }, + }, +} + +asr_static_pretrained_models = { + "deepspeech2offline_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz', + 'md5': + '4d26066c6f19f52087425dc722ae5b13', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_10', + 'model': + 'exp/deepspeech2/checkpoints/avg_10.jit.pdmodel', + 'params': + 'exp/deepspeech2/checkpoints/avg_10.jit.pdiparams', + 'lm_url': + 
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + } + }, +} + +# --------------------------------- +# -------------- CLS -------------- +# --------------------------------- +cls_dynamic_pretrained_models = { + "panns_cnn6-32k": { + '1.0': { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', + 'md5': '4cf09194a95df024fd12f84712cf0f9c', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn6.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, + "panns_cnn10-32k": { + '1.0': { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', + 'md5': 'cb8427b22176cc2116367d14847f5413', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn10.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, + "panns_cnn14-32k": { + '1.0': { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', + 'md5': 'e3b9b5614a1595001161d0ab95edee97', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn14.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, +} + +cls_static_pretrained_models = { + "panns_cnn6-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz', + 'md5': + 'da087c31046d23281d8ec5188c1967da', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, + "panns_cnn10-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz', + 'md5': + '5460cc6eafbfaf0f261cc75b90284ae1', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, + "panns_cnn14-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz', + 'md5': + 'ccc80b194821274da79466862b2ab00f', + 'cfg_path': + 'panns.yaml', + 'model_path': + 
'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, +} + +# --------------------------------- +# -------------- ST --------------- +# --------------------------------- +st_dynamic_pretrained_models = { + "fat_st_ted-en-zh": { + '1.0': { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", + "md5": + "d62063f35a16d91210a71081bd2dd557", + "cfg_path": + "model.yaml", + "ckpt_path": + "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", + }, + }, +} + +st_kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} + +# --------------------------------- +# -------------- TEXT ------------- +# --------------------------------- +text_dynamic_pretrained_models = { + "ernie_linear_p7_wudao-punc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', + 'md5': + '12283e2ddde1797c5d1e57036b512746', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + }, + "ernie_linear_p3_wudao-punc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', + 'md5': + '448eb2fdf85b6a997e7e652e80c51dd2', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + }, +} + +# --------------------------------- +# -------------- TTS -------------- +# --------------------------------- +tts_dynamic_pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', + 'md5': + '6f6fa967b408454b6662c8c00c0027cb', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 
'speech_stats': + 'feats_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + }, + }, + # fastspeech2 + "fastspeech2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "fastspeech2_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', + 'md5': + 'ffed800c93deaf16ca9b3af89bfcd747', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_100000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "fastspeech2_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', + 'md5': + 'f4dd4a5f49a4552b77981f544ab3392e', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_96400.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + }, + "fastspeech2_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', + 'md5': + '743e5024ca1e17a88c5c271db9779ba4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_66200.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + }, + # tacotron2 + "tacotron2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', + 'md5': + '0df4b6f0bcbe0d73c5ed6df8867ab91a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 
'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "tacotron2_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', + 'md5': + '6a5eddd81ae0e81d16959b97481135f3', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_60300.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + # pwgan + "pwgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', + 'md5': + '2e481633325b5bdf0a3823c714d2c117', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + }, + "pwgan_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', + 'md5': + '53610ba9708fd3008ccaf8e99dacbaf0', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + }, + "pwgan_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', + 'md5': + 'd7598fa41ad362d62f85ffc0f07e3d84', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "pwgan_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', + 'md5': + 'b3da1defcde3e578be71eb284cb89f2c', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 
'feats_stats.npy', + }, + }, + # style_melgan + "style_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + '5de2d5348f396de0c966926b8c462755', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # hifigan + "hifigan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', + 'md5': + '70e9131695decbca06a65fe51ed38a72', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', + 'md5': + '3bb49bc75032ed12f79c00c8cc79a09a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', + 'md5': + '7da8f88359bca2457e705d924cf27bd4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # wavernn + "wavernn_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', + 'md5': + 'ee37b752f09bcba8f2af3b777ca38e13', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_400000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + 
"fastspeech2_cnndecoder_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', + 'md5': + '6eb28e22ace73e0ebe7845f86478f89f', + 'config': + 'cnndecoder.yaml', + 'ckpt': + 'snapshot_iter_153000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, +} + +tts_static_pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip', + 'md5': + 'f10cbdedf47dc7a9668d2264494e1823', + 'model': + 'speedyspeech_csmsc.pdmodel', + 'params': + 'speedyspeech_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # fastspeech2 + "fastspeech2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip', + 'md5': + '9788cd9745e14c7a5d12d32670b2a5a7', + 'model': + 'fastspeech2_csmsc.pdmodel', + 'params': + 'fastspeech2_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # pwgan + "pwgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip', + 'md5': + 'e3504aed9c5a290be12d1347836d2742', + 'model': + 'pwgan_csmsc.pdmodel', + 'params': + 'pwgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip', + 'md5': + 'ac6eee94ba483421d750433f4c3b8d36', + 'model': + 'mb_melgan_csmsc.pdmodel', + 'params': + 'mb_melgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, + # hifigan + "hifigan_csmsc-zh": { + '1.0': { + 'url': + 
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip', + 'md5': + '7edd8c436b3a5546b3a7cb8cff9d5a0c', + 'model': + 'hifigan_csmsc.pdmodel', + 'params': + 'hifigan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, +} + +tts_onnx_pretrained_models = { + # fastspeech2 + "fastspeech2_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip', + 'md5': + 'fd3ad38d83273ad51f0ea4f4abf3ab4e', + 'ckpt': ['fastspeech2_csmsc.onnx'], + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + "fastspeech2_cnndecoder_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip', + 'md5': + '5f70e1a6bcd29d72d54e7931aa86f266', + 'ckpt': [ + 'fastspeech2_csmsc_am_encoder_infer.onnx', + 'fastspeech2_csmsc_am_decoder.onnx', + 'fastspeech2_csmsc_am_postnet.onnx', + ], + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # mb_melgan + "mb_melgan_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip', + 'md5': + '5b83ec746e8414bc29032d954ffd07ec', + 'ckpt': + 'mb_melgan_csmsc.onnx', + 'sample_rate': + 24000, + }, + }, + # hifigan + "hifigan_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip', + 'md5': + '1a7dc0385875889e46952e50c0994a6b', + 'ckpt': + 'hifigan_csmsc.onnx', + 'sample_rate': + 24000, + }, + }, +} + +# --------------------------------- +# ------------ Vector ------------- +# --------------------------------- +vector_dynamic_pretrained_models = { + "ecapatdnn_voxceleb12-16k": { + '1.0': { + 'url': + 
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', + 'md5': + 'cc33023c54ab346cd318408f43fcaf95', + 'cfg_path': + 'conf/model.yaml', # the yaml config path + 'ckpt_path': + 'model/model', # the format is ${dir}/{model_name}, + # so the first 'model' is dir, the second 'model' is the name + # this means we have a model stored as model/model.pdparams + }, + }, +} diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py new file mode 100644 index 000000000..369dba900 --- /dev/null +++ b/paddlespeech/resource/resource.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from collections import OrderedDict +from typing import Dict +from typing import List +from typing import Optional + +from ..cli.utils import download_and_decompress +from ..cli.utils import MODEL_HOME +from ..utils.dynamic_import import dynamic_import +from .model_alias import model_alias + +task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] +model_format_supported = ['dynamic', 'static', 'onnx'] +inference_mode_supported = ['online', 'offline'] + + +class CommonTaskResource: + def __init__(self, task: str, model_format: str='dynamic', **kwargs): + assert task in task_supported, 'Arg "task" must be one of {}.'.format( + task_supported) + assert model_format in model_format_supported, 'Arg "model_format" must be one of {}.'.format( + model_format_supported) + + self.task = task + self.model_format = model_format + self.pretrained_models = self._get_pretrained_models() + + if 'inference_mode' in kwargs: + assert kwargs[ + 'inference_mode'] in inference_mode_supported, 'Arg "inference_mode" must be one of {}.'.format( + inference_mode_supported) + self._inference_mode_filter(kwargs['inference_mode']) + + # Initialize after model and version had been set. + self.model_tag = None + self.version = None + self.res_dict = None + self.res_dir = None + + if self.task == 'tts': + # For vocoder + self.voc_model_tag = None + self.voc_version = None + self.voc_res_dict = None + self.voc_res_dir = None + + def set_task_model(self, + model_tag: str, + model_type: int=0, + version: Optional[str]=None): + """Set model tag and version of current task. + + Args: + model_tag (str): Model tag. + model_type (int): 0 for acoustic model otherwise vocoder in tts task. + version (Optional[str], optional): Version of pretrained model. Defaults to None. + """ + assert model_tag in self.pretrained_models, \ + "Can't find \"{}\" in resource. 
Model name must be one of {}".format(model_tag, list(self.pretrained_models.keys())) + + if version is None: + version = self._get_default_version(model_tag) + + assert version in self.pretrained_models[model_tag], \ + "Can't find version \"{}\" in \"{}\". Model name must be one of {}".format( + version, model_tag, list(self.pretrained_models[model_tag].keys())) + + if model_type == 0: + self.model_tag = model_tag + self.version = version + self.res_dict = self.pretrained_models[model_tag][version] + self._format_path(self.res_dict) + self.res_dir = self._fetch(self.res_dict, + self._get_model_dir(model_type)) + else: + assert self.task == 'tts', 'Vocoder will only be used in tts task.' + self.voc_model_tag = model_tag + self.voc_version = version + self.voc_res_dict = self.pretrained_models[model_tag][version] + self._format_path(self.voc_res_dict) + self.voc_res_dir = self._fetch(self.voc_res_dict, + self._get_model_dir(model_type)) + + @staticmethod + def get_model_class(model_name) -> List[object]: + """Dynamic import model class. + Args: + model_name (str): Model name. + + Returns: + List[object]: Return a list of model class. + """ + assert model_name in model_alias, 'No model classes found for "{}"'.format( + model_name) + + ret = [] + for import_path in model_alias[model_name]: + ret.append(dynamic_import(import_path)) + + if len(ret) == 1: + return ret[0] + else: + return ret + + def get_versions(self, model_tag: str) -> List[str]: + """List all available versions. + + Args: + model_tag (str): Model tag. + + Returns: + List[str]: Version list of model. + """ + return list(self.pretrained_models[model_tag].keys()) + + def _get_default_version(self, model_tag: str) -> str: + """Get default version of model. + + Args: + model_tag (str): Model tag. + + Returns: + str: Default version. + """ + return self.get_versions(model_tag)[-1] # get latest version + + def _get_model_dir(self, model_type: int=0) -> os.PathLike: + """Get resource directory. 
+ + Args: + model_type (int): 0 for acoustic model otherwise vocoder in tts task. + + Returns: + os.PathLike: Directory of model resource. + """ + if model_type == 0: + model_tag = self.model_tag + version = self.version + else: + model_tag = self.voc_model_tag + version = self.voc_version + + return os.path.join(MODEL_HOME, model_tag, version) + + def _get_pretrained_models(self) -> Dict[str, str]: + """Get all available models for current task. + + Returns: + Dict[str, str]: A dictionary with model tag and resources info. + """ + try: + import_models = '{}_{}_pretrained_models'.format(self.task, + self.model_format) + exec('from .pretrained_models import {}'.format(import_models)) + models = OrderedDict(locals()[import_models]) + except ImportError: + models = OrderedDict({}) # no models. + finally: + return models + + def _inference_mode_filter(self, inference_mode: Optional[str]): + """Filter models dict based on inference_mode. + + Args: + inference_mode (Optional[str]): 'online', 'offline' or None. + """ + if inference_mode is None: + return + + if self.task == 'asr': + online_flags = [ + 'online' in model_tag + for model_tag in self.pretrained_models.keys() + ] + for online_flag, model_tag in zip( + online_flags, list(self.pretrained_models.keys())): + if inference_mode == 'online' and online_flag: + continue + elif inference_mode == 'offline' and not online_flag: + continue + else: + del self.pretrained_models[model_tag] + elif self.task == 'tts': + # Hardcode for tts online models. 
+ tts_online_models = [ + 'fastspeech2_csmsc-zh', 'fastspeech2_cnndecoder_csmsc-zh', + 'mb_melgan_csmsc-zh', 'hifigan_csmsc-zh' + ] + for model_tag in list(self.pretrained_models.keys()): + if inference_mode == 'online' and model_tag in tts_online_models: + continue + elif inference_mode == 'offline': + continue + else: + del self.pretrained_models[model_tag] + else: + raise NotImplementedError('Only supports asr and tts task.') + + @staticmethod + def _fetch(res_dict: Dict[str, str], + target_dir: os.PathLike) -> os.PathLike: + """Fetch archive from url. + + Args: + res_dict (Dict[str, str]): Info dict of a resource. + target_dir (os.PathLike): Directory to save archives. + + Returns: + os.PathLike: Directory of model resource. + """ + return download_and_decompress(res_dict, target_dir) + + @staticmethod + def _format_path(res_dict: Dict[str, str]): + for k, v in res_dict.items(): + if isinstance(v, str) and '/' in v: + if v.startswith('https://') or v.startswith('http://'): + continue + else: + res_dict[k] = os.path.join(*(v.split('/'))) diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 2365071f3..2da68435c 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -189,25 +189,6 @@ if not hasattr(paddle.Tensor, 'contiguous'): paddle.static.Variable.contiguous = contiguous -def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: - nargs = len(args) - assert (nargs <= 1) - s = paddle.shape(xs) - if nargs == 1: - return s[args[0]] - else: - return s - - -#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.debug( - "override size of paddle.Tensor " - "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" 
-) -paddle.Tensor.size = size -paddle.static.Variable.size = size - - def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: return xs.reshape(args) @@ -219,7 +200,7 @@ if not hasattr(paddle.Tensor, 'view'): def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: - return xs.reshape(ys.size()) + return xs.reshape(paddle.shape(ys)) if not hasattr(paddle.Tensor, 'view_as'): diff --git a/paddlespeech/s2t/decoders/beam_search/beam_search.py b/paddlespeech/s2t/decoders/beam_search/beam_search.py index f331cb1c9..f6a2b4b0a 100644 --- a/paddlespeech/s2t/decoders/beam_search/beam_search.py +++ b/paddlespeech/s2t/decoders/beam_search/beam_search.py @@ -194,7 +194,7 @@ class BeamSearch(paddle.nn.Layer): Args: hyp (Hypothesis): Hypothesis with prefix tokens to score - ids (paddle.Tensor): 1D tensor of new partial tokens to score, + ids (paddle.Tensor): 1D tensor of new partial tokens to score, len(ids) < n_vocab x (paddle.Tensor): Corresponding input feature, (T, D) @@ -224,14 +224,14 @@ class BeamSearch(paddle.nn.Layer): ids (paddle.Tensor): The partial token ids(Global) to compute topk. Returns: - Tuple[paddle.Tensor, paddle.Tensor]: + Tuple[paddle.Tensor, paddle.Tensor]: The topk full token ids and partial token ids. Their shapes are `(self.beam_size,)`. i.e. (global ids, global relative local ids). 
""" # no pre beam performed, `ids` equal to `weighted_scores` - if weighted_scores.size(0) == ids.size(0): + if paddle.shape(weighted_scores)[0] == paddle.shape(ids)[0]: top_ids = weighted_scores.topk( self.beam_size)[1] # index in n_vocab return top_ids, top_ids @@ -370,13 +370,13 @@ class BeamSearch(paddle.nn.Layer): """ # set length bounds if maxlenratio == 0: - maxlen = x.shape[0] + maxlen = paddle.shape(x)[0] elif maxlenratio < 0: maxlen = -1 * int(maxlenratio) else: - maxlen = max(1, int(maxlenratio * x.size(0))) - minlen = int(minlenratio * x.size(0)) - logger.info("decoder input length: " + str(x.shape[0])) + maxlen = max(1, int(maxlenratio * paddle.shape(x)[0])) + minlen = int(minlenratio * paddle.shape(x)[0]) + logger.info("decoder input length: " + str(paddle.shape(x)[0])) logger.info("max output length: " + str(maxlen)) logger.info("min output length: " + str(minlen)) diff --git a/paddlespeech/s2t/decoders/scorers/ctc.py b/paddlespeech/s2t/decoders/scorers/ctc.py index 81d8b0783..3c1d4cf80 100644 --- a/paddlespeech/s2t/decoders/scorers/ctc.py +++ b/paddlespeech/s2t/decoders/scorers/ctc.py @@ -69,7 +69,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): return sc[i], st[i] else: # for CTCPrefixScorePD (need new_id > 0) r, log_psi, f_min, f_max, scoring_idmap = state - s = log_psi[i, new_id].expand(log_psi.size(1)) + s = log_psi[i, new_id].expand(paddle.shape(log_psi)[1]) if scoring_idmap is not None: return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max else: @@ -107,7 +107,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): """ logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 - xlen = paddle.to_tensor([logp.size(1)]) + xlen = paddle.to_tensor([paddle.shape(logp)[1]]) self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos) return None diff --git a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py index 78b8fe36c..a994412e0 100644 --- 
a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py +++ b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py @@ -33,9 +33,9 @@ class CTCPrefixScorePD(): self.logzero = -10000000000.0 self.blank = blank self.eos = eos - self.batch = x.size(0) - self.input_length = x.size(1) - self.odim = x.size(2) + self.batch = paddle.shape(x)[0] + self.input_length = paddle.shape(x)[1] + self.odim = paddle.shape(x)[2] self.dtype = x.dtype # Pad the rest of posteriors in the batch @@ -76,8 +76,8 @@ class CTCPrefixScorePD(): last_ids = [yi[-1] for yi in y] # last output label ids n_bh = len(last_ids) # batch * hyps n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps - self.scoring_num = scoring_ids.size( - -1) if scoring_ids is not None else 0 + self.scoring_num = paddle.shape(scoring_ids)[ + -1] if scoring_ids is not None else 0 # prepare state info if state is None: r_prev = paddle.full( @@ -153,7 +153,7 @@ class CTCPrefixScorePD(): # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h)) for t in range(start, end): - rp = r[t - 1] # (2 x BW x O') + rp = r[t - 1] # (2 x BW x O') rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view( 2, 2, n_bh, snum) # (2,2,BW,O') r[t] = paddle.logsumexp(rr, 1) + x_[:, t] @@ -227,7 +227,7 @@ class CTCPrefixScorePD(): if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O) # Pad the rest of posteriors in the batch # TODO(takaaki-hori): need a better way without for-loops - xlens = [x.size(1)] + xlens = [paddle.shape(x)[1]] for i, l in enumerate(xlens): if l < self.input_length: x[i, l:, :] = self.logzero @@ -237,7 +237,7 @@ class CTCPrefixScorePD(): xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) self.x = paddle.stack([xn, xb]) # (2, T, B, O) self.x[:, :tmp_x.shape[1], :, :] = tmp_x - self.input_length = x.size(1) + self.input_length = paddle.shape(x)[1] self.end_frames = paddle.to_tensor(xlens) - 1 def extend_state(self, state): @@ -318,16 +318,16 @@ class CTCPrefixScore(): 
r[0, 0] = xs[0] r[0, 1] = self.logzero else: - # Although the code does not exactly follow Algorithm 2, - # we don't have to change it because we can assume - # r_t(h)=0 for t < |h| in CTC forward computation + # Although the code does not exactly follow Algorithm 2, + # we don't have to change it because we can assume + # r_t(h)=0 for t < |h| in CTC forward computation # (Note: we assume here that index t starts with 0). # The purpose of this difference is to reduce the number of for-loops. # https://github.com/espnet/espnet/pull/3655 - # where we start to accumulate r_t(h) from t=|h| - # and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1, + # where we start to accumulate r_t(h) from t=|h| + # and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1, # avoiding accumulating zeros for t=1~|h|-1. - # Thus, we need to set r_{|h|-1}(h) = 0, + # Thus, we need to set r_{|h|-1}(h) = 0, # i.e., r[output_length-1] = logzero, for initialization. # This is just for reducing the computation. r[output_length - 1] = self.logzero diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index ee013d79e..049e7b688 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -32,13 +32,16 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save jit model to + # save jit model to parser.add_argument( "--export_path", type=str, help="path of the jit model to save") parser.add_argument( - "--model_type", type=str, default='offline', help="offline/online") + '--nxpu', + type=int, + default=0, + choices=[0, 1], + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() - print("model_type:{}".format(args.model_type)) print_arguments(args) # https://yaml.org/type/float.html diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py index 388b380d1..a9828f6e7 100644 --- 
a/paddlespeech/s2t/exps/deepspeech2/bin/test.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py @@ -32,14 +32,17 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') - # save asr result to + # save asr result to parser.add_argument( "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + '--nxpu', + type=int, + default=0, + choices=[0, 1], + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() print_arguments(args, globals()) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py index 707eb9e1b..8db081e7b 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -39,12 +39,15 @@ if __name__ == "__main__": parser.add_argument( "--export_path", type=str, help="path of the jit model to save") parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') + '--nxpu', + type=int, + default=0, + choices=[0, 1], + help="if nxpu == 0 and ngpu == 0, use cpu.") parser.add_argument( "--enable-auto-log", action="store_true", help="use auto log") args = parser.parse_args() print_arguments(args, globals()) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py index a909dd416..90b7d8a18 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py @@ -23,7 +23,6 @@ from yacs.config import CfgNode from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from 
paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.models.ds2 import DeepSpeech2Model -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint @@ -113,12 +112,7 @@ class DeepSpeech2Tester_hub(): config.input_dim = self.collate_fn_test.feature_size config.output_dim = self.collate_fn_test.vocab_size - if self.args.model_type == 'offline': - model = DeepSpeech2Model.from_config(config) - elif self.args.model_type == 'online': - model = DeepSpeech2ModelOnline.from_config(config) - else: - raise Exception("wrong model type") + model = DeepSpeech2Model.from_config(config) self.model = model @@ -172,8 +166,6 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') parser.add_argument("--audio_file", type=str, help='audio file path') # save asr result to parser.add_argument( @@ -184,7 +176,6 @@ if __name__ == "__main__": print("Please input the audio file path") sys.exit(-1) check(args.audio_file) - print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index e2c68d4be..fee7079d9 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -32,9 +32,12 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() parser.add_argument( - "--model_type", type=str, default='offline', help='offline/online') + '--nxpu', + type=int, + default=0, + choices=[0, 1], + help="if nxpu == 0 and ngpu == 0, use cpu.") args = parser.parse_args() - print("model_type:{}".format(args.model_type)) print_arguments(args, 
globals()) # https://yaml.org/type/float.html diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 3c2eaab72..511997a7c 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -22,17 +22,11 @@ import numpy as np import paddle from paddle import distributed as dist from paddle import inference -from paddle.io import DataLoader from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator -from paddlespeech.s2t.io.dataset import ManifestDataset -from paddlespeech.s2t.io.sampler import SortagradBatchSampler -from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler +from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model -from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog from paddlespeech.s2t.training.reporter import report from paddlespeech.s2t.training.timer import Timer @@ -136,18 +130,13 @@ class DeepSpeech2Trainer(Trainer): config = self.config.clone() with UpdateConfig(config): if self.train: - config.input_dim = self.train_loader.collate_fn.feature_size - config.output_dim = self.train_loader.collate_fn.vocab_size + config.input_dim = self.train_loader.feat_dim + config.output_dim = self.train_loader.vocab_size else: - config.input_dim = self.test_loader.collate_fn.feature_size - config.output_dim = self.test_loader.collate_fn.vocab_size + config.input_dim = self.test_loader.feat_dim + config.output_dim = self.test_loader.vocab_size - if self.args.model_type == 'offline': - model = DeepSpeech2Model.from_config(config) - elif self.args.model_type == 'online': - model = 
DeepSpeech2ModelOnline.from_config(config) - else: - raise Exception("wrong model type") + model = DeepSpeech2Model.from_config(config) if self.parallel: model = paddle.DataParallel(model) @@ -175,76 +164,80 @@ class DeepSpeech2Trainer(Trainer): config = self.config.clone() config.defrost() if self.train: - # train - config.manifest = config.train_manifest - train_dataset = ManifestDataset.from_config(config) - if self.parallel: - batch_sampler = SortagradDistributedBatchSampler( - train_dataset, - batch_size=config.batch_size, - num_replicas=None, - rank=None, - shuffle=True, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - else: - batch_sampler = SortagradBatchSampler( - train_dataset, - shuffle=True, - batch_size=config.batch_size, - drop_last=True, - sortagrad=config.sortagrad, - shuffle_method=config.shuffle_method) - - config.keep_transcription_text = False - collate_fn_train = SpeechCollator.from_config(config) - self.train_loader = DataLoader( - train_dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn_train, - num_workers=config.num_workers) - - # dev - config.manifest = config.dev_manifest - dev_dataset = ManifestDataset.from_config(config) - - config.augmentation_config = "" - config.keep_transcription_text = False - collate_fn_dev = SpeechCollator.from_config(config) - self.valid_loader = DataLoader( - dev_dataset, - batch_size=int(config.batch_size), - shuffle=False, - drop_last=False, - collate_fn=collate_fn_dev, - num_workers=config.num_workers) - logger.info("Setup train/valid Dataloader!") + # train/valid dataset, return token ids + self.train_loader = BatchDataLoader( + json_file=config.train_manifest, + train_mode=True, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, + mini_batch_size=self.args.ngpu, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + 
batch_frames_in=config.batch_frames_in, + batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + shortest_first=False) + + self.valid_loader = BatchDataLoader( + json_file=config.dev_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=self.args.ngpu, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + shortest_first=False) + logger.info("Setup train/valid Dataloader!") else: - # test - config.manifest = config.test_manifest - test_dataset = ManifestDataset.from_config(config) - - config.augmentation_config = "" - config.keep_transcription_text = True - collate_fn_test = SpeechCollator.from_config(config) decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - self.test_loader = DataLoader( - test_dataset, + # test dataset, return raw text + self.test_loader = BatchDataLoader( + json_file=config.test_manifest, + train_mode=False, + sortagrad=False, batch_size=decode_batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_test, - num_workers=config.num_workers) - logger.info("Setup test Dataloader!") + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + logger.info("Setup test/align Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): def 
__init__(self, config, args): super().__init__(config, args) self._text_featurizer = TextFeaturizer( - unit_type=config.unit_type, vocab=None) + unit_type=config.unit_type, vocab=config.vocab_filepath) + self.vocab_list = self._text_featurizer.vocab_list def ordid2token(self, texts, texts_len): """ ord() id to chr() chr """ @@ -252,7 +245,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) + trans.append( + self._text_featurizer.defeaturize(ids.numpy().tolist())) return trans def compute_metrics(self, @@ -307,8 +301,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # Initialized the decoder in model decode_cfg = self.config.decode - vocab_list = self.test_loader.collate_fn.vocab_list - decode_batch_size = self.test_loader.batch_size + vocab_list = self.vocab_list + decode_batch_size = decode_cfg.decode_batch_size self.model.decoder.init_decoder( decode_batch_size, vocab_list, decode_cfg.decoding_method, decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, @@ -338,17 +332,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): @paddle.no_grad() def export(self): - if self.args.model_type == 'offline': - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - elif self.args.model_type == 'online': - infer_model = DeepSpeech2InferModelOnline.from_pretrained( - self.test_loader, self.config, self.args.checkpoint_path) - else: - raise Exception("wrong model type") - + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) infer_model.eval() - feat_dim = self.test_loader.collate_fn.feature_size static_model = infer_model.export() logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) @@ -376,10 +362,10 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): # Initialized the decoder 
in model decode_cfg = self.config.decode - vocab_list = self.test_loader.collate_fn.vocab_list - if self.args.model_type == "online": + vocab_list = self.vocab_list + if self.config.rnn_direction == "forward": decode_batch_size = 1 - elif self.args.model_type == "offline": + elif self.config.rnn_direction == "bidirect": decode_batch_size = self.test_loader.batch_size else: raise Exception("wrong model type") @@ -412,11 +398,11 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.model.decoder.del_decoder() def compute_result_transcripts(self, audio, audio_len): - if self.args.model_type == "online": + if self.config.rnn_direction == "forward": output_probs, output_lens, trans_batch = self.static_forward_online( audio, audio_len, decoder_chunk_size=1) result_transcripts = [trans[-1] for trans in trans_batch] - elif self.args.model_type == "offline": + elif self.config.rnn_direction == "bidirect": output_probs, output_lens = self.static_forward_offline(audio, audio_len) batch_size = output_probs.shape[0] diff --git a/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py b/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py index 22329d5e0..ac5720fd5 100644 --- a/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py @@ -14,10 +14,11 @@ """Contains the audio featurizer class.""" import numpy as np import paddle -import paddleaudio.compliance.kaldi as kaldi from python_speech_features import delta from python_speech_features import mfcc +import paddlespeech.audio.compliance.kaldi as kaldi + class AudioFeaturizer(): """Audio featurizer, for extracting features from audio contents of diff --git a/paddlespeech/s2t/io/dataset.py b/paddlespeech/s2t/io/dataset.py index 0e94f047b..9987b5110 100644 --- a/paddlespeech/s2t/io/dataset.py +++ b/paddlespeech/s2t/io/dataset.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2021 Mobvoi Inc. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index 89752bb9f..ac55af123 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/s2t/models/ds2/__init__.py b/paddlespeech/s2t/models/ds2/__init__.py index b32220673..480f6d3af 100644 --- a/paddlespeech/s2t/models/ds2/__init__.py +++ b/paddlespeech/s2t/models/ds2/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys + from .deepspeech2 import DeepSpeech2InferModel from .deepspeech2 import DeepSpeech2Model from paddlespeech.s2t.utils import dynamic_pip_install @@ -20,7 +22,8 @@ try: except ImportError: try: package_name = 'paddlespeech_ctcdecoders' - dynamic_pip_install.install(package_name) + if sys.platform != "win32": + dynamic_pip_install.install(package_name) except Exception: raise RuntimeError( "Can not install package paddlespeech_ctcdecoders on your system. \ diff --git a/paddlespeech/s2t/models/ds2/conv.py b/paddlespeech/s2t/models/ds2/conv.py index 4e766e793..448d4d1bb 100644 --- a/paddlespeech/s2t/models/ds2/conv.py +++ b/paddlespeech/s2t/models/ds2/conv.py @@ -11,161 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from paddle import nn -from paddle.nn import functional as F +import paddle -from paddlespeech.s2t.modules.activation import brelu -from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 -logger = Log(__name__).getlog() -__all__ = ['ConvStack', "conv_output_size"] +class Conv2dSubsampling4Pure(Conv2dSubsampling4): + def __init__(self, idim: int, odim: int, dropout_rate: float): + super().__init__(idim, odim, dropout_rate, None) + self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim + self.receptive_field_length = 2 * ( + 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kernel_size_1 - -def conv_output_size(I, F, P, S): - # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters - # Output size after Conv: - # By noting I the length of the input volume size, - # F the length of the filter, - # P the amount of zero padding, - # S the stride, - # then the output size O of the feature map along that dimension is given by: - # O = (I - F + Pstart + Pend) // S + 1 - # When Pstart == Pend == P, we can replace Pstart + Pend by 2P. 
- # When Pstart == Pend == 0 - # O = (I - F - S) // S - # https://iq.opengenus.org/output-size-of-convolution/ - # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1 - # Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1 - return (I - F + 2 * P - S) // S - - -# receptive field calculator -# https://fomoro.com/research/article/receptive-field-calculator -# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters -# https://distill.pub/2019/computing-receptive-fields/ -# Rl-1 = Sl * Rl + (Kl - Sl) - - -class ConvBn(nn.Layer): - """Convolution layer with batch normalization. - - :param kernel_size: The x dimension of a filter kernel. Or input a tuple for - two image dimension. - :type kernel_size: int|tuple|list - :param num_channels_in: Number of input channels. - :type num_channels_in: int - :param num_channels_out: Number of output channels. - :type num_channels_out: int - :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. - :type stride: int|tuple|list - :param padding: The x dimension of the padding. Or input a tuple for two - image dimension. - :type padding: int|tuple|list - :param act: Activation type, relu|brelu - :type act: string - :return: Batch norm layer after convolution layer. 
- :rtype: Variable - - """ - - def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, - padding, act): - - super().__init__() - assert len(kernel_size) == 2 - assert len(stride) == 2 - assert len(padding) == 2 - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - - self.conv = nn.Conv2D( - num_channels_in, - num_channels_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_attr=None, - bias_attr=False, - data_format='NCHW') - - self.bn = nn.BatchNorm2D( - num_channels_out, - weight_attr=None, - bias_attr=None, - data_format='NCHW') - self.act = F.relu if act == 'relu' else brelu - - def forward(self, x, x_len): - """ - x(Tensor): audio, shape [B, C, D, T] - """ + def forward(self, x: paddle.Tensor, + x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) x = self.conv(x) - x = self.bn(x) - x = self.act(x) - - x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] - ) // self.stride[1] + 1 - - # reset padding part to 0 - masks = make_non_pad_mask(x_len) #[B, T] - masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - return x, x_len - - -class ConvStack(nn.Layer): - """Convolution group with stacked convolution layers. - - :param feat_size: audio feature dim. - :type feat_size: int - :param num_stacks: Number of stacked convolution layers. 
- :type num_stacks: int - """ - - def __init__(self, feat_size, num_stacks): - super().__init__() - self.feat_size = feat_size # D - self.num_stacks = num_stacks - - self.conv_in = ConvBn( - num_channels_in=1, - num_channels_out=32, - kernel_size=(41, 11), #[D, T] - stride=(2, 3), - padding=(20, 5), - act='brelu') - - out_channel = 32 - convs = [ - ConvBn( - num_channels_in=32, - num_channels_out=out_channel, - kernel_size=(21, 11), - stride=(2, 1), - padding=(10, 5), - act='brelu') for i in range(num_stacks - 1) - ] - self.conv_stack = nn.LayerList(convs) - - # conv output feat_dim - output_height = (feat_size - 1) // 2 + 1 - for i in range(self.num_stacks - 1): - output_height = (output_height - 1) // 2 + 1 - self.output_height = out_channel * output_height - - def forward(self, x, x_len): - """ - x: shape [B, C, D, T] - x_len : shape [B] - """ - x, x_len = self.conv_in(x, x_len) - for i, conv in enumerate(self.conv_stack): - x, x_len = conv(x, x_len) + #b, c, t, f = paddle.shape(x) #not work under jit + x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) + x_len = ((x_len - 1) // 2 - 1) // 2 return x, x_len diff --git a/paddlespeech/s2t/models/ds2/deepspeech2.py b/paddlespeech/s2t/models/ds2/deepspeech2.py index 9c6b66c25..b7ee80a7d 100644 --- a/paddlespeech/s2t/models/ds2/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2/deepspeech2.py @@ -13,15 +13,14 @@ # limitations under the License. 
"""Deepspeech2 ASR Model""" import paddle +import paddle.nn.functional as F from paddle import nn -from paddlespeech.s2t.models.ds2.conv import ConvStack -from paddlespeech.s2t.models.ds2.rnn import RNNStack +from paddlespeech.s2t.models.ds2.conv import Conv2dSubsampling4Pure from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint from paddlespeech.s2t.utils.log import Log - logger = Log(__name__).getlog() __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] @@ -32,72 +31,197 @@ class CRNNEncoder(nn.Layer): feat_size, dict_size, num_conv_layers=2, - num_rnn_layers=3, + num_rnn_layers=4, rnn_size=1024, - use_gru=False, - share_rnn_weights=True): + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): super().__init__() self.rnn_size = rnn_size self.feat_size = feat_size # 161 for linear self.dict_size = dict_size - - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + self.num_rnn_layers = num_rnn_layers + self.num_fc_layers = num_fc_layers + self.rnn_direction = rnn_direction + self.fc_layers_size_list = fc_layers_size_list + self.use_gru = use_gru + self.conv = Conv2dSubsampling4Pure(feat_size, 32, dropout_rate=0.0) + + self.output_dim = self.conv.output_dim + + i_size = self.conv.output_dim + self.rnn = nn.LayerList() + self.layernorm_list = nn.LayerList() + self.fc_layers_list = nn.LayerList() + if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': + layernorm_size = 2 * rnn_size + elif rnn_direction == 'forward': + layernorm_size = rnn_size + else: + raise Exception("Wrong rnn direction") + for i in range(0, num_rnn_layers): + if i == 0: + rnn_input_size = i_size + else: + rnn_input_size = layernorm_size 
+ if use_gru is True: + self.rnn.append( + nn.GRU( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + else: + self.rnn.append( + nn.LSTM( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + self.layernorm_list.append(nn.LayerNorm(layernorm_size)) + self.output_dim = layernorm_size + + fc_input_size = layernorm_size + for i in range(self.num_fc_layers): + self.fc_layers_list.append( + nn.Linear(fc_input_size, fc_layers_size_list[i])) + fc_input_size = fc_layers_size_list[i] + self.output_dim = fc_layers_size_list[i] @property def output_size(self): - return self.rnn_size * 2 + return self.output_dim - def forward(self, audio, audio_len): + def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): """Compute Encoder outputs Args: - audio (Tensor): [B, Tmax, D] - text (Tensor): [B, Umax] - audio_len (Tensor): [B] - text_len (Tensor): [B] - Returns: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + Return: x (Tensor): encoder outputs, [B, T, D] x_lens (Tensor): encoder length, [B] + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] """ - # [B, T, D] -> [B, D, T] - audio = audio.transpose([0, 2, 1]) - # [B, D, T] -> [B, C=1, D, T] - x = audio.unsqueeze(1) - x_lens = audio_len + if init_state_h_box is not None: + init_state_list = None + + if self.use_gru is True: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_list = init_state_h_list + else: + init_state_h_list = paddle.split( + init_state_h_box, 
self.num_rnn_layers, axis=0) + init_state_c_list = paddle.split( + init_state_c_box, self.num_rnn_layers, axis=0) + init_state_list = [(init_state_h_list[i], init_state_c_list[i]) + for i in range(self.num_rnn_layers)] + else: + init_state_list = [None] * self.num_rnn_layers - # convolution group x, x_lens = self.conv(x, x_lens) + final_chunk_state_list = [] + for i in range(0, self.num_rnn_layers): + x, final_state = self.rnn[i](x, init_state_list[i], + x_lens) #[B, T, D] + final_chunk_state_list.append(final_state) + x = self.layernorm_list[i](x) + + for i in range(self.num_fc_layers): + x = self.fc_layers_list[i](x) + x = F.relu(x) + + if self.use_gru is True: + final_chunk_state_h_box = paddle.concat( + final_chunk_state_list, axis=0) + final_chunk_state_c_box = init_state_c_box + else: + final_chunk_state_h_list = [ + final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_chunk_state_c_list = [ + final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_chunk_state_h_box = paddle.concat( + final_chunk_state_h_list, axis=0) + final_chunk_state_c_box = paddle.concat( + final_chunk_state_c_list, axis=0) + + return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box + + def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): + """Compute Encoder outputs - # convert data from convolution feature map to sequence of vectors - #B, C, D, T = paddle.shape(x) # not work under jit - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit - x = x.reshape([0, 0, -1]) #[B, T, C*D] - - # remove padding part - x, x_lens = self.rnn(x, x_lens) #[B, T, D] - return x, x_lens + Args: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + decoder_chunk_size: The chunk size of decoder + Returns: + eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: 
[B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + """ + subsampling_rate = self.conv.subsampling_rate + receptive_field_length = self.conv.receptive_field_length + chunk_size = (decoder_chunk_size - 1 + ) * subsampling_rate + receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + max_len = x.shape[1] + assert (chunk_size <= max_len) + + eouts_chunk_list = [] + eouts_chunk_lens_list = [] + if (max_len - chunk_size) % chunk_stride != 0: + padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride + else: + padding_len = 0 + padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) + padded_x = paddle.concat([x, padding], axis=1) + num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 + num_chunk = int(num_chunk) + chunk_state_h_box = None + chunk_state_c_box = None + final_state_h_box = None + final_state_c_box = None + for i in range(0, num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + + x_len_left = paddle.where(x_lens - i * chunk_stride < 0, + paddle.zeros_like(x_lens), + x_lens - i * chunk_stride) + x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size + x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, + x_len_left, x_chunk_len_tmp) + + eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( + x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) + + eouts_chunk_list.append(eouts_chunk) + eouts_chunk_lens_list.append(eouts_chunk_lens) + final_state_h_box = chunk_state_h_box + final_state_c_box = chunk_state_c_box + return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box class DeepSpeech2Model(nn.Layer): """The DeepSpeech2 network structure. 
- :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable + :param audio: Audio spectrogram data layer. + :type audio: Variable + :param text: Transcription text data layer. + :type text: Variable :param audio_len: Valid sequence length data layer. :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable + :param feat_size: feature size for audio. + :type feat_size: int :param dict_size: Dictionary size for tokenized transcription. :type dict_size: int :param num_conv_layers: Number of stacking convolution layers. @@ -106,37 +230,41 @@ class DeepSpeech2Model(nn.Layer): :type num_rnn_layers: int :param rnn_size: RNN layer size (dimension of RNN cells). :type rnn_size: int + :param num_fc_layers: Number of stacking FC layers. + :type num_fc_layers: int + :param fc_layers_size_list: The list of FC layer sizes. + :type fc_layers_size_list: [int,] :param use_gru: Use gru if set True. Use simple rnn if set False. :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool :return: A tuple of an output unnormalized log probability layer ( before softmax) and a ctc cost layer. 
:rtype: tuple of LayerOutput """ - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - blank_id=0, - ctc_grad_norm_type=None): + def __init__( + self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False, + blank_id=0, + ctc_grad_norm_type=None, ): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, dict_size=dict_size, num_conv_layers=num_conv_layers, num_rnn_layers=num_rnn_layers, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - assert (self.encoder.output_size == rnn_size * 2) + use_gru=use_gru) self.decoder = CTCDecoder( odim=dict_size, # is in vocab @@ -151,7 +279,7 @@ class DeepSpeech2Model(nn.Layer): """Compute Model loss Args: - audio (Tensors): [B, T, D] + audio (Tensor): [B, T, D] audio_len (Tensor): [B] text (Tensor): [B, U] text_len (Tensor): [B] @@ -159,22 +287,22 @@ class DeepSpeech2Model(nn.Layer): Returns: loss (Tensor): [1] """ - eouts, eouts_len = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) loss = self.decoder(eouts, eouts_len, text, text_len) return loss @paddle.no_grad() def decode(self, audio, audio_len): # decoders only accept string encoded in utf-8 - # Make sure the decoder has been initialized - eouts, eouts_len = self.encoder(audio, audio_len) + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) probs = self.decoder.softmax(eouts) batch_size = probs.shape[0] self.decoder.reset_decoder(batch_size=batch_size) self.decoder.next(probs, eouts_len) trans_best, trans_beam = self.decoder.decode() - return trans_best @classmethod @@ -196,13 
+324,15 @@ class DeepSpeech2Model(nn.Layer): The model built from pretrained result. """ model = cls( - feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, + feat_size=dataloader.feat_dim, + dict_size=dataloader.vocab_size, num_conv_layers=config.num_conv_layers, num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights, blank_id=config.blank_id, ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) infos = Checkpoint().load_parameters( @@ -229,8 +359,10 @@ class DeepSpeech2Model(nn.Layer): num_conv_layers=config.num_conv_layers, num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights, blank_id=config.blank_id, ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) return model @@ -240,28 +372,50 @@ class DeepSpeech2InferModel(DeepSpeech2Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def forward(self, audio, audio_len): - """export model function - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - - Returns: - probs: probs after softmax - """ - eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.softmax(eouts) - return probs, eouts_len + def forward(self, + audio_chunk, + audio_chunk_lens, + chunk_state_h_box=None, + chunk_state_c_box=None): + if self.encoder.rnn_direction == "forward": + eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( + audio_chunk, audio_chunk_lens, chunk_state_h_box, + chunk_state_c_box) + probs_chunk = self.decoder.softmax(eouts_chunk) + return probs_chunk, eouts_chunk_lens, 
final_state_h_box, final_state_c_box + elif self.encoder.rnn_direction == "bidirect": + eouts, eouts_len, _, _ = self.encoder(audio_chunk, audio_chunk_lens) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + else: + raise Exception("wrong model type") def export(self): - static_model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, self.encoder.feat_size], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) + if self.encoder.rnn_direction == "forward": + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size + ], #[B, chunk_size, feat_dim] + dtype='float32'), + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') + ]) + elif self.encoder.rnn_direction == "bidirect": + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + else: + raise Exception("wrong model type") return static_model diff --git a/paddlespeech/s2t/models/ds2/rnn.py b/paddlespeech/s2t/models/ds2/rnn.py deleted file mode 100644 index f655b2d82..000000000 --- a/paddlespeech/s2t/models/ds2/rnn.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -from paddlespeech.s2t.modules.activation import brelu -from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.utils.log import Log - -logger = Log(__name__).getlog() - -__all__ = ['RNNStack'] - - -class RNNCell(nn.RNNCellBase): - r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it - computes the outputs and updates states. - The formula used is as follows: - .. math:: - h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) - y_{t} & = h_{t} - - where :math:`act` is for :attr:`activation`. 
- """ - - def __init__(self, - hidden_size: int, - activation="tanh", - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - if activation not in ["tanh", "relu", "brelu"]: - raise ValueError( - "activation for SimpleRNNCell should be tanh or relu, " - "but get {}".format(activation)) - self.activation = activation - self._activation_fn = paddle.tanh \ - if activation == "tanh" \ - else F.relu - if activation == 'brelu': - self._activation_fn = brelu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - pre_h = states - i2h = inputs - if self.bias_ih is not None: - i2h += self.bias_ih - h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h2h += self.bias_hh - h = self._activation_fn(i2h + h2h) - return h, h - - @property - def state_shape(self): - return (self.hidden_size, ) - - -class GRUCell(nn.RNNCellBase): - r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, - it computes the outputs and updates states. - The formula for GRU used is as follows: - .. math:: - r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) - z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) - \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) - h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} - y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise - multiplication operator. 
- """ - - def __init__(self, - input_size: int, - hidden_size: int, - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (3 * hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - self.input_size = input_size - self._gate_activation = F.sigmoid - self._activation = paddle.tanh - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - - pre_hidden = states - x_gates = inputs - if self.bias_ih is not None: - x_gates = x_gates + self.bias_ih - h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h_gates = h_gates + self.bias_hh - - x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) - h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) - - r = self._gate_activation(x_r + h_r) - z = self._gate_activation(x_z + h_z) - c = self._activation(x_c + r * h_c) # apply reset gate after mm - h = (pre_hidden - c) * z + c - # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru - - return h, h - - @property - def state_shape(self): - r""" - The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch - size would be automatically inserted into shape). The shape corresponds - to the shape of :math:`h_{t-1}`. - """ - return (self.hidden_size, ) - - -class BiRNNWithBN(nn.Layer): - """Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param size: Dimension of RNN cells. 
- :type size: int - :param share_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - :type share_weights: bool - :return: Bidirectional simple rnn layer. - :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int, share_weights: bool): - super().__init__() - self.share_weights = share_weights - if self.share_weights: - #input-hidden weights shared between bi-directional rnn. - self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - # batch norm is only performed on input-state projection - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = self.fw_fc - self.bw_bn = self.fw_bn - else: - self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - - self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class BiGRUWithBN(nn.Layer): - """Bidirectonal gru layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer. - :type name: string - :param input: Input layer. - :type input: Variable - :param size: Dimension of GRU cells. - :type size: int - :param act: Activation type. 
- :type act: string - :return: Bidirectional GRU layer. - :rtype: Variable - """ - - def __init__(self, i_size: int, h_size: int): - super().__init__() - hidden_size = h_size * 3 - - self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - - self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class RNNStack(nn.Layer): - """RNN group with stacked bidirectional simple RNN or GRU layers. - - :param input: Input layer. - :type input: Variable - :param size: Dimension of RNN cells in each layer. - :type size: int - :param num_stacks: Number of stacked rnn layers. - :type num_stacks: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: Output layer of the RNN group. 
- :rtype: Variable - """ - - def __init__(self, - i_size: int, - h_size: int, - num_stacks: int, - use_gru: bool, - share_rnn_weights: bool): - super().__init__() - rnn_stacks = [] - for i in range(num_stacks): - if use_gru: - #default:GRU using tanh - rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) - else: - rnn_stacks.append( - BiRNNWithBN( - i_size=i_size, - h_size=h_size, - share_weights=share_rnn_weights)) - i_size = h_size * 2 - - self.rnn_stacks = nn.LayerList(rnn_stacks) - - def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): - """ - x: shape [B, T, D] - x_len: shpae [B] - """ - for i, rnn in enumerate(self.rnn_stacks): - x, x_len = rnn(x, x_len) - masks = make_non_pad_mask(x_len) #[B, T] - masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) - - return x, x_len diff --git a/paddlespeech/s2t/models/ds2_online/__init__.py b/paddlespeech/s2t/models/ds2_online/__init__.py deleted file mode 100644 index c5fdab1bc..000000000 --- a/paddlespeech/s2t/models/ds2_online/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from .deepspeech2 import DeepSpeech2InferModelOnline -from .deepspeech2 import DeepSpeech2ModelOnline -from paddlespeech.s2t.utils import dynamic_pip_install - -try: - import paddlespeech_ctcdecoders -except ImportError: - try: - package_name = 'paddlespeech_ctcdecoders' - dynamic_pip_install.install(package_name) - except Exception: - raise RuntimeError( - "Can not install package paddlespeech_ctcdecoders on your system. \ - The DeepSpeech2 model is not supported for your system") - -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] diff --git a/paddlespeech/s2t/models/ds2_online/conv.py b/paddlespeech/s2t/models/ds2_online/conv.py deleted file mode 100644 index 25a9715a3..000000000 --- a/paddlespeech/s2t/models/ds2_online/conv.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import paddle - -from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 - - -class Conv2dSubsampling4Online(Conv2dSubsampling4): - def __init__(self, idim: int, odim: int, dropout_rate: float): - super().__init__(idim, odim, dropout_rate, None) - self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim - self.receptive_field_length = 2 * ( - 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kerel_size_1 - - def forward(self, x: paddle.Tensor, - x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: - x = x.unsqueeze(1) # (b, c=1, t, f) - x = self.conv(x) - #b, c, t, f = paddle.shape(x) #not work under jit - x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) - x_len = ((x_len - 1) // 2 - 1) // 2 - return x, x_len diff --git a/paddlespeech/s2t/models/ds2_online/deepspeech2.py b/paddlespeech/s2t/models/ds2_online/deepspeech2.py deleted file mode 100644 index 9574a62bd..000000000 --- a/paddlespeech/s2t/models/ds2_online/deepspeech2.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Deepspeech2 ASR Online Model""" -import paddle -import paddle.nn.functional as F -from paddle import nn - -from paddlespeech.s2t.models.ds2_online.conv import Conv2dSubsampling4Online -from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.utils import layer_tools -from paddlespeech.s2t.utils.checkpoint import Checkpoint -from paddlespeech.s2t.utils.log import Log -logger = Log(__name__).getlog() - -__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] - - -class CRNNEncoder(nn.Layer): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False): - super().__init__() - self.rnn_size = rnn_size - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - self.num_rnn_layers = num_rnn_layers - self.num_fc_layers = num_fc_layers - self.rnn_direction = rnn_direction - self.fc_layers_size_list = fc_layers_size_list - self.use_gru = use_gru - self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) - - self.output_dim = self.conv.output_dim - - i_size = self.conv.output_dim - self.rnn = nn.LayerList() - self.layernorm_list = nn.LayerList() - self.fc_layers_list = nn.LayerList() - if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': - layernorm_size = 2 * rnn_size - elif rnn_direction == 'forward': - layernorm_size = rnn_size - else: - raise Exception("Wrong rnn direction") - for i in range(0, num_rnn_layers): - if i == 0: - rnn_input_size = i_size - else: - rnn_input_size = layernorm_size - if use_gru is True: - self.rnn.append( - nn.GRU( - input_size=rnn_input_size, - hidden_size=rnn_size, - num_layers=1, - direction=rnn_direction)) - else: - self.rnn.append( - nn.LSTM( - input_size=rnn_input_size, - hidden_size=rnn_size, - num_layers=1, - direction=rnn_direction)) - self.layernorm_list.append(nn.LayerNorm(layernorm_size)) - self.output_dim = 
layernorm_size - - fc_input_size = layernorm_size - for i in range(self.num_fc_layers): - self.fc_layers_list.append( - nn.Linear(fc_input_size, fc_layers_size_list[i])) - fc_input_size = fc_layers_size_list[i] - self.output_dim = fc_layers_size_list[i] - - @property - def output_size(self): - return self.output_dim - - def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): - """Compute Encoder outputs - - Args: - x (Tensor): [B, T, D] - x_lens (Tensor): [B] - init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - Return: - x (Tensor): encoder outputs, [B, T, D] - x_lens (Tensor): encoder length, [B] - final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - """ - if init_state_h_box is not None: - init_state_list = None - - if self.use_gru is True: - init_state_h_list = paddle.split( - init_state_h_box, self.num_rnn_layers, axis=0) - init_state_list = init_state_h_list - else: - init_state_h_list = paddle.split( - init_state_h_box, self.num_rnn_layers, axis=0) - init_state_c_list = paddle.split( - init_state_c_box, self.num_rnn_layers, axis=0) - init_state_list = [(init_state_h_list[i], init_state_c_list[i]) - for i in range(self.num_rnn_layers)] - else: - init_state_list = [None] * self.num_rnn_layers - - x, x_lens = self.conv(x, x_lens) - final_chunk_state_list = [] - for i in range(0, self.num_rnn_layers): - x, final_state = self.rnn[i](x, init_state_list[i], - x_lens) #[B, T, D] - final_chunk_state_list.append(final_state) - x = self.layernorm_list[i](x) - - for i in range(self.num_fc_layers): - x = self.fc_layers_list[i](x) - x = F.relu(x) - - if self.use_gru is True: - 
final_chunk_state_h_box = paddle.concat( - final_chunk_state_list, axis=0) - final_chunk_state_c_box = init_state_c_box - else: - final_chunk_state_h_list = [ - final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) - ] - final_chunk_state_c_list = [ - final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) - ] - final_chunk_state_h_box = paddle.concat( - final_chunk_state_h_list, axis=0) - final_chunk_state_c_box = paddle.concat( - final_chunk_state_c_list, axis=0) - - return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box - - def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): - """Compute Encoder outputs - - Args: - x (Tensor): [B, T, D] - x_lens (Tensor): [B] - decoder_chunk_size: The chunk size of decoder - Returns: - eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks - eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks - final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] - """ - subsampling_rate = self.conv.subsampling_rate - receptive_field_length = self.conv.receptive_field_length - chunk_size = (decoder_chunk_size - 1 - ) * subsampling_rate + receptive_field_length - chunk_stride = subsampling_rate * decoder_chunk_size - max_len = x.shape[1] - assert (chunk_size <= max_len) - - eouts_chunk_list = [] - eouts_chunk_lens_list = [] - if (max_len - chunk_size) % chunk_stride != 0: - padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride - else: - padding_len = 0 - padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) - padded_x = paddle.concat([x, padding], axis=1) - num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 - num_chunk = int(num_chunk) - chunk_state_h_box = None - chunk_state_c_box = None - 
final_state_h_box = None - final_state_c_box = None - for i in range(0, num_chunk): - start = i * chunk_stride - end = start + chunk_size - x_chunk = padded_x[:, start:end, :] - - x_len_left = paddle.where(x_lens - i * chunk_stride < 0, - paddle.zeros_like(x_lens), - x_lens - i * chunk_stride) - x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size - x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, - x_len_left, x_chunk_len_tmp) - - eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( - x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) - - eouts_chunk_list.append(eouts_chunk) - eouts_chunk_lens_list.append(eouts_chunk_lens) - final_state_h_box = chunk_state_h_box - final_state_c_box = chunk_state_c_box - return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box - - -class DeepSpeech2ModelOnline(nn.Layer): - """The DeepSpeech2 network structure for online. - - :param audio: Audio spectrogram data layer. - :type audio: Variable - :param text: Transcription text data layer. - :type text: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param feat_size: feature size for audio. - :type feat_size: int - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param num_fc_layers: Number of stacking FC layers. - :type num_fc_layers: int - :param fc_layers_size_list: The list of FC layer sizes. - :type fc_layers_size_list: [int,] - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. 
- :rtype: tuple of LayerOutput - """ - - def __init__( - self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False, - blank_id=0, - ctc_grad_norm_type=None, ): - super().__init__() - self.encoder = CRNNEncoder( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_direction=rnn_direction, - num_fc_layers=num_fc_layers, - fc_layers_size_list=fc_layers_size_list, - rnn_size=rnn_size, - use_gru=use_gru) - - self.decoder = CTCDecoder( - odim=dict_size, # is in vocab - enc_n_units=self.encoder.output_size, - blank_id=blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=ctc_grad_norm_type) - - def forward(self, audio, audio_len, text, text_len): - """Compute Model loss - - Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] - text (Tensor): [B, U] - text_len (Tensor): [B] - - Returns: - loss (Tensor): [1] - """ - eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( - audio, audio_len, None, None) - loss = self.decoder(eouts, eouts_len, text, text_len) - return loss - - @paddle.no_grad() - def decode(self, audio, audio_len): - # decoders only accept string encoded in utf-8 - # Make sure the decoder has been initialized - eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( - audio, audio_len, None, None) - probs = self.decoder.softmax(eouts) - batch_size = probs.shape[0] - self.decoder.reset_decoder(batch_size=batch_size) - self.decoder.next(probs, eouts_len) - trans_best, trans_beam = self.decoder.decode() - return trans_best - - @classmethod - def from_pretrained(cls, dataloader, config, checkpoint_path): - """Build a DeepSpeech2Model model from a pretrained model. 
- Parameters - ---------- - dataloader: paddle.io.DataLoader - - config: yacs.config.CfgNode - model configs - - checkpoint_path: Path or str - the path of pretrained model checkpoint, without extension name - - Returns - ------- - DeepSpeech2ModelOnline - The model built from pretrained result. - """ - model = cls( - feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - rnn_direction=config.rnn_direction, - num_fc_layers=config.num_fc_layers, - fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru, - blank_id=config.blank_id, - ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) - infos = Checkpoint().load_parameters( - model, checkpoint_path=checkpoint_path) - logger.info(f"checkpoint info: {infos}") - layer_tools.summary(model) - return model - - @classmethod - def from_config(cls, config): - """Build a DeepSpeec2ModelOnline from config - Parameters - - config: yacs.config.CfgNode - config - Returns - ------- - DeepSpeech2ModelOnline - The model built from config. 
- """ - model = cls( - feat_size=config.input_dim, - dict_size=config.output_dim, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - rnn_direction=config.rnn_direction, - num_fc_layers=config.num_fc_layers, - fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru, - blank_id=config.blank_id, - ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), ) - return model - - -class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, - chunk_state_c_box): - eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( - audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) - probs_chunk = self.decoder.softmax(eouts_chunk) - return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box - - def export(self): - static_model = paddle.jit.to_static( - self, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, - self.encoder.feat_size], #[B, chunk_size, feat_dim] - dtype='float32'), - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32'), - paddle.static.InputSpec( - shape=[None, None, None], dtype='float32') - ]) - return static_model diff --git a/paddlespeech/s2t/models/lm/transformer.py b/paddlespeech/s2t/models/lm/transformer.py index 85bd7c232..04ddddf86 100644 --- a/paddlespeech/s2t/models/lm/transformer.py +++ b/paddlespeech/s2t/models/lm/transformer.py @@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): def _target_mask(self, ys_in_pad): ys_mask = ys_in_pad != 0 - m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0) + m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0) return ys_mask.unsqueeze(-2) & m def forward(self, x: paddle.Tensor, t: 
paddle.Tensor @@ -112,7 +112,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): in perplexity: p(t)^{-n} = exp(-log p(t) / n) """ - batch_size = x.size(0) + batch_size = paddle.shape(x)[0] xm = x != 0 xlen = xm.sum(axis=1) if self.embed_drop is not None: @@ -122,7 +122,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): h, _ = self.encoder(emb, xlen) y = self.decoder(h) loss = F.cross_entropy( - y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + y.view(-1, paddle.shape(y)[-1]), t.view(-1), reduction="none") mask = xm.to(loss.dtype) logp = loss * mask.view(-1) nll = logp.view(batch_size, -1).sum(-1) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 530840d0f..b4b61666f 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -1,3 +1,4 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -775,7 +776,7 @@ class U2DecodeModel(U2BaseModel): """ self.eval() x = paddle.to_tensor(x).unsqueeze(0) - ilen = x.size(1) + ilen = paddle.shape(x)[1] enc_output, _ = self._forward_encoder(x, ilen) return enc_output.squeeze(0) diff --git a/paddlespeech/s2t/models/u2/updater.py b/paddlespeech/s2t/models/u2/updater.py index c59090a84..bb18fe416 100644 --- a/paddlespeech/s2t/models/u2/updater.py +++ b/paddlespeech/s2t/models/u2/updater.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# Modified from wenet(https://github.com/wenet-e2e/wenet) from contextlib import nullcontext import paddle diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472de..0f50db21d 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys from typing import Union import paddle @@ -34,7 +35,8 @@ except ImportError: try: from paddlespeech.s2t.utils import dynamic_pip_install package_name = 'paddlespeech_ctcdecoders' - dynamic_pip_install.install(package_name) + if sys.platform != "win32": + dynamic_pip_install.install(package_name) from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 42ac119b4..ccc8482d5 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -242,7 +242,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ] # batch decoding - ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) + ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0) # (B,L,L) xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) logp, states = self.forward_one_step( xs, xs_mask, ys, ys_mask, cache=batch_state) diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py index 596f61b78..51e558eb8 100644 --- a/paddlespeech/s2t/modules/embedding.py +++ b/paddlespeech/s2t/modules/embedding.py @@ -115,7 +115,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface): assert offset + x.shape[ 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger 
than the max_len: {}".format( offset, x.shape[1], self.max_len) - #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor + #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + T] x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) @@ -165,6 +165,6 @@ class RelPositionalEncoding(PositionalEncoding): 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( offset, x.shape[1], self.max_len) x = x * self.xscale - #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor + #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]] return self.dropout(x), self.dropout(pos_emb) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 669a12d65..4d31acf1a 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -218,7 +218,7 @@ class BaseEncoder(nn.Layer): assert xs.shape[0] == 1 # batch size must be one # tmp_masks is just for interface compatibility # TODO(Hui Zhang): stride_slice not support bool tensor - # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) + # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool) tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32) tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index 84da251aa..a7eb9892d 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -112,7 +112,16 @@ class Trainer(): logger.info(f"Rank: {self.rank}/{self.world_size}") # set device - paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + if self.args.ngpu == 0: + if self.args.nxpu == 0: + paddle.set_device('cpu') + else: + paddle.set_device('xpu') + elif self.args.ngpu > 0: + paddle.set_device("gpu") + else: + raise 
Exception("invalid device") + if self.parallel: self.init_parallel() diff --git a/paddlespeech/s2t/transform/perturb.py b/paddlespeech/s2t/transform/perturb.py index 9e41b824b..b18caefb8 100644 --- a/paddlespeech/s2t/transform/perturb.py +++ b/paddlespeech/s2t/transform/perturb.py @@ -154,7 +154,8 @@ class SpeedPerturbationSox(): package = "sox" dynamic_pip_install.install(package) package = "soxbindings" - dynamic_pip_install.install(package) + if sys.platform != "win32": + dynamic_pip_install.install(package) import soxbindings as sox except Exception: raise RuntimeError( diff --git a/paddlespeech/s2t/transform/spectrogram.py b/paddlespeech/s2t/transform/spectrogram.py index 2a93bedc8..19f0237bf 100644 --- a/paddlespeech/s2t/transform/spectrogram.py +++ b/paddlespeech/s2t/transform/spectrogram.py @@ -15,9 +15,10 @@ import librosa import numpy as np import paddle -import paddleaudio.compliance.kaldi as kaldi from python_speech_features import logfbank +import paddlespeech.audio.compliance.kaldi as kaldi + def stft(x, n_fft, diff --git a/paddlespeech/s2t/utils/ctc_utils.py b/paddlespeech/s2t/utils/ctc_utils.py index 886b72033..42564d8e1 100644 --- a/paddlespeech/s2t/utils/ctc_utils.py +++ b/paddlespeech/s2t/utils/ctc_utils.py @@ -1,3 +1,4 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py index 0dbaa0b6b..f9a843ea1 100644 --- a/paddlespeech/s2t/utils/tensor_utils.py +++ b/paddlespeech/s2t/utils/tensor_utils.py @@ -58,7 +58,7 @@ def pad_sequence(sequences: List[paddle.Tensor], >>> a = paddle.ones(25, 300) >>> b = paddle.ones(22, 300) >>> c = paddle.ones(15, 300) - >>> pad_sequence([a, b, c]).size() + >>> pad_sequence([a, b, c]).shape paddle.Tensor([25, 3, 300]) Note: @@ -79,10 +79,11 @@ def pad_sequence(sequences: List[paddle.Tensor], # assuming trailing dimensions and type of all the Tensors # in sequences are same and fetching those from sequences[0] - max_size = sequences[0].size() + max_size = paddle.shape(sequences[0]) # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] - trailing_dims = max_size[1:] if max_size.ndim >= 2 else () + trailing_dims = tuple( + max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims @@ -99,7 +100,7 @@ def pad_sequence(sequences: List[paddle.Tensor], if batch_first: # TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support int16 - # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] + # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] # out_tensor[i, :length, ...] 
= tensor if length != 0: out_tensor[i, :length] = tensor @@ -145,7 +146,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, [ 4, 5, 6, 11, -1, -1], [ 7, 8, 9, 11, -1, -1]]) """ - # TODO(Hui Zhang): using comment code, + # TODO(Hui Zhang): using comment code, #_sos = paddle.to_tensor( # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) #_eos = paddle.to_tensor( diff --git a/paddlespeech/s2t/utils/text_grid.py b/paddlespeech/s2t/utils/text_grid.py index cbd9856e4..e696f43d5 100644 --- a/paddlespeech/s2t/utils/text_grid.py +++ b/paddlespeech/s2t/utils/text_grid.py @@ -1,3 +1,4 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index 74e7ce3fe..fb521b309 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -752,6 +752,7 @@ class VectorClientExecutor(BaseExecutor): res = handler.run(enroll_audio, test_audio, audio_format, sample_rate) logger.info(f"The vector score is: {res}") + return res else: logger.error(f"Sorry, we have not support such task {task}") diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index 578a0a8a8..57d728872 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -26,7 +26,9 @@ from ..executor import BaseExecutor from ..util import cli_server_register from ..util import stats_wrapper from paddlespeech.cli.log import logger +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.engine_pool import init_engine_pool +from paddlespeech.server.engine.engine_warmup import warm_up from paddlespeech.server.restful.api import setup_router as setup_http_router from 
paddlespeech.server.utils.config import get_config from paddlespeech.server.ws.api import setup_router as setup_ws_router @@ -90,6 +92,11 @@ class ServerExecutor(BaseExecutor): if not init_engine_pool(config): return False + # warm up + for engine_and_type in config.engine_list: + if not warm_up(engine_and_type): + return False + return True def execute(self, argv: List[str]) -> bool: @@ -156,101 +163,30 @@ class ServerStatsExecutor(): "Please input correct speech task, choices = ['asr', 'tts']") return False - elif self.task.lower() == 'asr': - try: - from paddlespeech.cli.asr.infer import pretrained_models - logger.info( - "Here is the table of ASR pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - # show ASR static pretrained model - from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models - logger.info( - "Here is the table of ASR static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of ASR pretrained models supported in the service." - ) - return False - - elif self.task.lower() == 'tts': - try: - from paddlespeech.cli.tts.infer import pretrained_models - logger.info( - "Here is the table of TTS pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - # show TTS static pretrained model - from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models - logger.info( - "Here is the table of TTS static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of TTS pretrained models supported in the service." 
- ) - return False - - elif self.task.lower() == 'cls': - try: - from paddlespeech.cli.cls.infer import pretrained_models - logger.info( - "Here is the table of CLS pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) + try: + # Dynamic models + dynamic_pretrained_models = CommonTaskResource( + task=self.task, model_format='dynamic').pretrained_models - # show CLS static pretrained model - from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models + if len(dynamic_pretrained_models) > 0: logger.info( - "Here is the table of CLS static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of CLS pretrained models supported in the service." - ) - return False - elif self.task.lower() == 'text': - try: - from paddlespeech.cli.text.infer import pretrained_models + "Here is the table of {} pretrained models supported in the service.". + format(self.task.upper())) + self.show_support_models(dynamic_pretrained_models) + + # Static models + static_pretrained_models = CommonTaskResource( + task=self.task, model_format='static').pretrained_models + if len(static_pretrained_models) > 0: logger.info( - "Here is the table of Text pretrained models supported in the service." - ) + "Here is the table of {} static pretrained models supported in the service.". + format(self.task.upper())) self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error( - "Failed to get the table of Text pretrained models supported in the service." - ) - return False - elif self.task.lower() == 'vector': - try: - from paddlespeech.cli.vector.infer import pretrained_models - logger.info( - "Here is the table of Vector pretrained models supported in the service." 
- ) - self.show_support_models(pretrained_models) + return True - return True - except BaseException: - logger.error( - "Failed to get the table of Vector pretrained models supported in the service." - ) - return False - else: + except BaseException: logger.error( - f"Failed to get the table of {self.task} pretrained models supported in the service." - ) + "Failed to get the table of {} pretrained models supported in the service.". + format(self.task.upper())) return False diff --git a/paddlespeech/server/conf/tts_online_application.yaml b/paddlespeech/server/conf/tts_online_application.yaml index 964e85ef9..0460a5e16 100644 --- a/paddlespeech/server/conf/tts_online_application.yaml +++ b/paddlespeech/server/conf/tts_online_application.yaml @@ -47,7 +47,7 @@ tts_online: am_pad: 12 # voc_pad and voc_block voc model to streaming voc infer, # when voc model is mb_melgan_csmsc, voc_pad set 14, streaming synthetic audio is the same as non-streaming synthetic audio; The minimum value of pad can be set to 7, streaming synthetic audio sounds normal - # when voc model is hifigan_csmsc, voc_pad set 20, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal + # when voc model is hifigan_csmsc, voc_pad set 19, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal voc_block: 36 voc_pad: 14 @@ -95,7 +95,7 @@ tts_online-onnx: am_pad: 12 # voc_pad and voc_block voc model to streaming voc infer, # when voc model is mb_melgan_csmsc_onnx, voc_pad set 14, streaming synthetic audio is the same as non-streaming synthetic audio; The minimum value of pad can be set to 7, streaming synthetic audio sounds normal - # when voc model is hifigan_csmsc_onnx, voc_pad set 20, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal + # when voc model is hifigan_csmsc_onnx, voc_pad 
set 19, streaming synthetic audio is the same as non-streaming synthetic audio; voc_pad set 14, streaming synthetic audio sounds normal voc_block: 36 voc_pad: 14 # voc_upsample should be same as n_shift on voc config. diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index dee8d78ba..43d83f2d4 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -28,7 +28,9 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: force_yes: True + device: # cpu or gpu:id am_predictor_conf: device: # set 'gpu:id' or 'cpu' diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 9c0425345..d72eb2379 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -28,8 +28,11 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: -1 force_yes: True device: # cpu or gpu:id + continuous_decoding: True # enable continue decoding when endpoint detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True @@ -42,4 +45,4 @@ asr_online: window_ms: 25 # ms shift_ms: 10 # ms sample_rate: 16000 - sample_width: 2 \ No newline at end of file + sample_width: 2 diff --git a/paddlespeech/server/conf/ws_conformer_wenetspeech_application_faster.yaml b/paddlespeech/server/conf/ws_conformer_wenetspeech_application_faster.yaml new file mode 100644 index 000000000..ba413c802 --- /dev/null +++ b/paddlespeech/server/conf/ws_conformer_wenetspeech_application_faster.yaml @@ -0,0 +1,48 @@ +# This is the parameter configuration file for PaddleSpeech Serving. 
+ +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# task choices = ['asr_online'] +# protocol = ['websocket'] (only one can be selected). +# websocket only support online engine type. +protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer_online_wenetspeech' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + device: 'cpu' # cpu or gpu:id + decode_method: "attention_rescoring" + continuous_decoding: True # enable continue decoding when endpoint detected + num_decoding_left_chunks: 16 + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 diff --git a/paddlespeech/server/download.py b/paddlespeech/server/download.py deleted file mode 100644 index ea943dd87..000000000 --- a/paddlespeech/server/download.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import hashlib -import os -import os.path as osp -import shutil -import subprocess -import tarfile -import time -import zipfile - -import requests -from tqdm import tqdm - -from paddlespeech.cli.log import logger - -__all__ = ['get_path_from_url'] - -DOWNLOAD_RETRY_LIMIT = 3 - - -def _is_url(path): - """ - Whether path is URL. - Args: - path (string): URL string or not. - """ - return path.startswith('http://') or path.startswith('https://') - - -def _map_path(url, root_dir): - # parse path after download under root_dir - fname = osp.split(url)[-1] - fpath = fname - return osp.join(root_dir, fpath) - - -def _get_unique_endpoints(trainer_endpoints): - # Sorting is to avoid different environmental variables for each card - trainer_endpoints.sort() - ips = set() - unique_endpoints = set() - for endpoint in trainer_endpoints: - ip = endpoint.split(":")[0] - if ip in ips: - continue - ips.add(ip) - unique_endpoints.add(endpoint) - logger.info("unique_endpoints {}".format(unique_endpoints)) - return unique_endpoints - - -def get_path_from_url(url, - root_dir, - md5sum=None, - check_exist=True, - decompress=True, - method='get'): - """ Download from given url to root_dir. - if file or directory specified by url is exists under - root_dir, return the path directly, otherwise download - from url and decompress it, return the path. 
- Args: - url (str): download url - root_dir (str): root dir for downloading, it should be - WEIGHTS_HOME or DATASET_HOME - md5sum (str): md5 sum of download package - decompress (bool): decompress zip or tar file. Default is `True` - method (str): which download method to use. Support `wget` and `get`. Default is `get`. - Returns: - str: a local path to save downloaded models & weights & datasets. - """ - - from paddle.fluid.dygraph.parallel import ParallelEnv - - assert _is_url(url), "downloading from {} not a url".format(url) - # parse path after download to decompress under root_dir - fullpath = _map_path(url, root_dir) - # Mainly used to solve the problem of downloading data from different - # machines in the case of multiple machines. Different ips will download - # data, and the same ip will only download data once. - unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:]) - if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): - logger.info("Found {}".format(fullpath)) - else: - if ParallelEnv().current_endpoint in unique_endpoints: - fullpath = _download(url, root_dir, md5sum, method=method) - else: - while not os.path.exists(fullpath): - time.sleep(1) - - if ParallelEnv().current_endpoint in unique_endpoints: - if decompress and (tarfile.is_tarfile(fullpath) or - zipfile.is_zipfile(fullpath)): - fullpath = _decompress(fullpath) - - return fullpath - - -def _get_download(url, fullname): - # using requests.get method - fname = osp.basename(fullname) - try: - req = requests.get(url, stream=True) - except Exception as e: # requests.exceptions.ConnectionError - logger.info("Downloading {} from {} failed with exception {}".format( - fname, url, str(e))) - return False - - if req.status_code != 200: - raise RuntimeError("Downloading from {} failed with code " - "{}!".format(url, req.status_code)) - - # For protecting download interupted, download to - # tmp_fullname firstly, move tmp_fullname to fullname - # after download 
finished - tmp_fullname = fullname + "_tmp" - total_size = req.headers.get('content-length') - with open(tmp_fullname, 'wb') as f: - if total_size: - with tqdm(total=(int(total_size) + 1023) // 1024) as pbar: - for chunk in req.iter_content(chunk_size=1024): - f.write(chunk) - pbar.update(1) - else: - for chunk in req.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - shutil.move(tmp_fullname, fullname) - - return fullname - - -def _wget_download(url, fullname): - # using wget to download url - tmp_fullname = fullname + "_tmp" - # –user-agent - command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT, - url) - subprc = subprocess.Popen( - command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - _ = subprc.communicate() - - if subprc.returncode != 0: - raise RuntimeError( - '{} failed. Please make sure `wget` is installed or {} exists'. - format(command, url)) - - shutil.move(tmp_fullname, fullname) - - return fullname - - -_download_methods = { - 'get': _get_download, - 'wget': _wget_download, -} - - -def _download(url, path, md5sum=None, method='get'): - """ - Download from url, save to path. - url (str): download url - path (str): download to given path - md5sum (str): md5 sum of download package - method (str): which download method to use. Support `wget` and `get`. Default is `get`. - """ - assert method in _download_methods, 'make sure `{}` implemented'.format( - method) - - if not osp.exists(path): - os.makedirs(path) - - fname = osp.split(url)[-1] - fullname = osp.join(path, fname) - retry_cnt = 0 - - logger.info("Downloading {} from {}".format(fname, url)) - while not (osp.exists(fullname) and _md5check(fullname, md5sum)): - if retry_cnt < DOWNLOAD_RETRY_LIMIT: - retry_cnt += 1 - else: - raise RuntimeError("Download from {} failed. 
" - "Retry limit reached".format(url)) - - if not _download_methods[method](url, fullname): - time.sleep(1) - continue - - return fullname - - -def _md5check(fullname, md5sum=None): - if md5sum is None: - return True - - logger.info("File {} md5 checking...".format(fullname)) - md5 = hashlib.md5() - with open(fullname, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5.update(chunk) - calc_md5sum = md5.hexdigest() - - if calc_md5sum != md5sum: - logger.info("File {} md5 check failed, {}(calc) != " - "{}(base)".format(fullname, calc_md5sum, md5sum)) - return False - return True - - -def _decompress(fname): - """ - Decompress for zip and tar file - """ - logger.info("Decompressing {}...".format(fname)) - - # For protecting decompressing interupted, - # decompress to fpath_tmp directory firstly, if decompress - # successed, move decompress files to fpath and delete - # fpath_tmp and remove download compress file. - - if tarfile.is_tarfile(fname): - uncompressed_path = _uncompress_file_tar(fname) - elif zipfile.is_zipfile(fname): - uncompressed_path = _uncompress_file_zip(fname) - else: - raise TypeError("Unsupport compress file type {}".format(fname)) - - return uncompressed_path - - -def _uncompress_file_zip(filepath): - files = zipfile.ZipFile(filepath, 'r') - file_list = files.namelist() - - file_dir = os.path.dirname(filepath) - - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) - - for item in file_list: - files.extract(item, file_dir) - - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] - uncompressed_path = os.path.join(file_dir, rootpath) - - for item in file_list: - files.extract(item, file_dir) - - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - for item in file_list: - 
files.extract(item, os.path.join(file_dir, rootpath)) - - files.close() - - return uncompressed_path - - -def _uncompress_file_tar(filepath, mode="r:*"): - files = tarfile.open(filepath, mode) - file_list = files.getnames() - - file_dir = os.path.dirname(filepath) - - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - - for item in file_list: - files.extract(item, os.path.join(file_dir, rootpath)) - - files.close() - - return uncompressed_path - - -def _is_a_single_file(file_list): - if len(file_list) == 1 and file_list[0].find(os.sep) < -1: - return True - return False - - -def _is_a_single_dir(file_list): - new_file_list = [] - for file_path in file_list: - if '/' in file_path: - file_path = file_path.replace('/', os.sep) - elif '\\' in file_path: - file_path = file_path.replace('\\', os.sep) - new_file_list.append(file_path) - - file_name = new_file_list[0].split(os.sep)[0] - for i in range(1, len(new_file_list)): - if file_name != new_file_list[i].split(os.sep)[0]: - return False - return True diff --git a/paddlespeech/server/engine/acs/__init__.py b/paddlespeech/server/engine/acs/__init__.py index e69de29bb..97043fd7b 100644 --- a/paddlespeech/server/engine/acs/__init__.py +++ b/paddlespeech/server/engine/acs/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/acs/python/__init__.py b/paddlespeech/server/engine/acs/python/__init__.py index e69de29bb..97043fd7b 100644 --- a/paddlespeech/server/engine/acs/python/__init__.py +++ b/paddlespeech/server/engine/acs/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/server/engine/acs/python/acs_engine.py b/paddlespeech/server/engine/acs/python/acs_engine.py index 30deeeb50..930101ac9 100644 --- a/paddlespeech/server/engine/acs/python/acs_engine.py +++ b/paddlespeech/server/engine/acs/python/acs_engine.py @@ -16,6 +16,7 @@ import json import os import re +import numpy as np import paddle import soundfile import websocket @@ -44,11 +45,10 @@ class ACSEngine(BaseEngine): logger.info("Init the acs engine") try: self.config = config - if self.config.device: - self.device = self.config.device - else: - self.device = paddle.get_device() + self.device = self.config.get("device", paddle.get_device()) + # websocket default ping timeout is 20 seconds + self.ping_timeout = self.config.get("ping_timeout", 20) paddle.set_device(self.device) logger.info(f"ACS Engine set the device: {self.device}") @@ -100,8 +100,8 @@ class ACSEngine(BaseEngine): logger.error("No asr server, please input valid ip and port") return "" ws = websocket.WebSocket() - ws.connect(self.url) - # with websocket.WebSocket.connect(self.url) as ws: + logger.info(f"set the ping timeout: {self.ping_timeout} seconds") + ws.connect(self.url, ping_timeout=self.ping_timeout) audio_info = json.dumps( { "name": "test.wav", @@ -116,11 +116,11 @@ class ACSEngine(BaseEngine): logger.info("client receive msg={}".format(msg)) # send the total audio data - samples, sample_rate = soundfile.read(audio_data, dtype='int16') - ws.send_binary(samples.tobytes()) - msg = ws.recv() - msg = json.loads(msg) - logger.info(f"audio result: {msg}") + for chunk_data in self.read_wave(audio_data): + ws.send_binary(chunk_data.tobytes()) + msg = ws.recv() + msg = json.loads(msg) + logger.info(f"audio result: {msg}") # 3. 
send chunk audio data to engine logger.info("send the end signal") @@ -142,6 +142,39 @@ class ACSEngine(BaseEngine): return msg + def read_wave(self, audio_data: str): + """read the audio file from specific wavfile path + + Args: + audio_data (str): the audio data, + we assume that audio sample rate matches the model + + Yields: + numpy.array: the samall package audio pcm data + """ + samples, sample_rate = soundfile.read(audio_data, dtype='int16') + x_len = len(samples) + assert sample_rate == 16000 + + chunk_size = int(85 * sample_rate / 1000) # 85ms, sample_rate = 16kHz + + if x_len % chunk_size != 0: + padding_len_x = chunk_size - x_len % chunk_size + else: + padding_len_x = 0 + + padding = np.zeros((padding_len_x), dtype=samples.dtype) + padded_x = np.concatenate([samples, padding], axis=0) + + assert (x_len + padding_len_x) % chunk_size == 0 + num_chunk = (x_len + padding_len_x) / chunk_size + num_chunk = int(num_chunk) + for i in range(0, num_chunk): + start = i * chunk_size + end = start + chunk_size + x_chunk = padded_x[start:end] + yield x_chunk + def get_macthed_word(self, msg): """Get the matched info in msg diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index fd57a3d52..cb6a42b74 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import os import sys +from typing import ByteString from typing import Optional import numpy as np @@ -21,24 +21,23 @@ import paddle from numpy import float32 from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.transform.transformation import Transformation -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine -from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor -__all__ = ['ASREngine'] +__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] # ASR server connection process class @@ -53,25 +52,39 @@ class PaddleASRConnectionHanddler: logger.info( "create an paddle asr connection handler to process the websocket connection" ) - self.config = asr_engine.config + self.config = asr_engine.config # server config self.model_config = asr_engine.executor.config self.asr_engine = asr_engine - self.init() - self.reset() - - def init(self): # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = 
self.asr_engine.executor.model_type self.sample_rate = self.asr_engine.executor.sample_rate # tokens to text self.text_feature = self.asr_engine.executor.text_feature - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - from paddlespeech.s2t.io.collator import SpeechCollator + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window and frame shift, in samples unit + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, ( + self.sample_rate, self.preprocess_conf.process[0]['fs']) + self.frame_shift_in_ms = int( + self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000) + + self.continuous_decoding = self.config.get("continuous_decoding", False) + self.init_decoder() + self.reset() + + def init_decoder(self): + if "deepspeech2" in self.model_type: + assert self.continuous_decoding is False, "ds2 model not support endpoint" self.am_predictor = self.asr_engine.executor.am_predictor - self.collate_fn_test = SpeechCollator.from_config(self.model_config) self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab enc_n_units=self.model_config.rnn_layer_size * 2, @@ -89,188 +102,185 @@ class PaddleASRConnectionHanddler: cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window samples length and frame shift samples length - - self.win_length = int(self.model_config.window_ms / 1000 * - self.sample_rate) - self.n_shift = int(self.model_config.stride_ms / 1000 * - self.sample_rate) elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model + 
self.continuous_decoding = self.config.continuous_decoding + logger.info(f"continue decoding: {self.continuous_decoding}") # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract feat, new only fbank in conformer model - self.preprocess_conf = self.model_config.preprocess_config - self.preprocess_args = {"train": False} - self.preprocessing = Transformation(self.preprocess_conf) + # ctc endpoint + self.endpoint_opt = OnlineCTCEndpoingOpt( + frame_shift_in_ms=self.frame_shift_in_ms, blank=0) + self.endpointer = OnlineCTCEndpoint(self.endpoint_opt) + else: + raise ValueError(f"Not supported: {self.model_type}") - # frame window samples length and frame shift samples length - self.win_length = self.preprocess_conf.process[0]['win_length'] - self.n_shift = self.preprocess_conf.process[0]['n_shift'] + def model_reset(self): + if "deepspeech2" in self.model_type: + return - def extract_feat(self, samples): + # cache for audio and feat + self.remained_wav = None + self.cached_feat = None - # we compute the elapsed time of first char occuring - # and we record the start time at the first pcm sample arraving - # if self.first_char_occur_elapsed is not None: - # self.first_char_occur_elapsed = time.time() + ## conformer + # cache for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + # conformer decoding state + self.offset = 0 # global offset in decoding frame unit - if "deepspeech2online" in self.model_type: - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 + ## just for record info + self.chunk_num = 0 # global decoding chunk num, not used - # pcm16 -> pcm 32 - # pcm2float will change the orignal samples, - # so we shoule do pcm2float before 
concatenate - samples = pcm2float(samples) + def output_reset(self): + ## outputs + # partial/ending decoding results + self.result_transcripts = [''] + # token timestamp result + self.word_time_stamp = [] - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim == 1 - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The connection remain the audio samples: {self.remained_wav.shape}" - ) + ## just for record + self.hyps = [] - # read audio - speech_segment = SpeechSegment.from_pcm( - self.remained_wav, self.sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) + # one best timestamp viterbi prob is large. + self.time_stamp = [] - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) + def reset_continuous_decoding(self): + """ + when in continous decoding, reset for next utterance. + """ + self.global_frame_offset = self.num_frames + self.model_reset() + self.searcher.reset() + self.endpointer.reset() - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature( - spectrum) + # reset hys will trancate history transcripts. 
+ # self.output_reset() - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) + def reset(self): + if "deepspeech2" in self.model_type: + # for deepspeech2 + # init state + self.chunk_state_h_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), + dtype=float32) + self.decoder.reset_decoder(batch_size=1) - if self.cached_feat is None: - self.cached_feat = audio - else: - assert (len(audio.shape) == 3) - assert (len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat( - [self.cached_feat, audio], axis=1) + if "conformer" in self.model_type or "transformer" in self.model_type: + self.searcher.reset() + self.endpointer.reset() - # set the feat device - if self.device is None: - self.device = self.cached_feat.place + self.device = None - self.num_frames += audio_len - self.remained_wav = self.remained_wav[self.n_shift * audio_len:] + ## common + # global sample and frame step + self.num_samples = 0 + self.global_frame_offset = 0 + # frame step of cur utterance + self.num_frames = 0 - logger.info( - f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" - ) - elif "conformer_online" in self.model_type: - logger.info("Online ASR extract the feat") - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 - - logger.info(f"This package receive {samples.shape[0]} pcm data") - self.num_samples += samples.shape[0] - - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim 
== 1 - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The connection remain the audio samples: {self.remained_wav.shape}" - ) - if len(self.remained_wav) < self.win_length: - return 0 - - # fbank - x_chunk = self.preprocessing(self.remained_wav, - **self.preprocess_args) - x_chunk = paddle.to_tensor( - x_chunk, dtype="float32").unsqueeze(axis=0) - if self.cached_feat is None: - self.cached_feat = x_chunk - else: - assert (len(x_chunk.shape) == 3) - assert (len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat( - [self.cached_feat, x_chunk], axis=1) + ## endpoint + self.endpoint_state = False # True for detect endpoint - # set the feat device - if self.device is None: - self.device = self.cached_feat.place + ## conformer + self.model_reset() - num_frames = x_chunk.shape[1] - self.num_frames += num_frames - self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + ## outputs + self.output_reset() - logger.info( - f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" - ) - # logger.info(f"accumulate samples: {self.num_samples}") + def extract_feat(self, samples: ByteString): + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 - def reset(self): - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - # for deepspeech2 - self.chunk_state_h_box = copy.deepcopy( - self.asr_engine.executor.chunk_state_h_box) - self.chunk_state_c_box = copy.deepcopy( - self.asr_engine.executor.chunk_state_c_box) - self.decoder.reset_decoder(batch_size=1) + self.num_samples += samples.shape[0] + logger.info( + f"This package receive {samples.shape[0]} pcm data. 
Global samples:{self.num_samples}" + ) - # for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None - self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] - self.word_time_stamp = [] - self.time_stamp = [] - self.first_char_occur_elapsed = None + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 # (T,) + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" + ) + + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + # cur frame step + num_frames = x_chunk.shape[1] + + # global frame step + self.num_frames += num_frames + + # update remained wav + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" + ) + logger.info(f"global samples: 
{self.num_samples}") + logger.info(f"global frames: {self.num_frames}") def decode(self, is_finished=False): - if "deepspeech2online" in self.model_type: - # x_chunk 是特征数据 - decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model - context = 7 # context=7 in deepspeech2 model - subsampling = 4 # subsampling=4 in deepspeech2 model - stride = subsampling * decoding_chunk_size + """advance decoding + + Args: + is_finished (bool, optional): Is last frame or not. Defaults to False. + + Returns: + None: + """ + if "deepspeech2" in self.model_type: + decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + + context = 7 # context=7, in audio frame unit + subsampling = 4 # subsampling=4, in audio frame unit + cached_feature_num = context - subsampling - # decoding window for model + # decoding window for model, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride for model, in audio frame unit + stride = subsampling * decoding_chunk_size if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") @@ -280,6 +290,7 @@ class PaddleASRConnectionHanddler: logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" ) + # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( @@ -293,6 +304,7 @@ class PaddleASRConnectionHanddler: "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None + logger.info("start to do model forward") # num_frames - context + 1 ensure that current frame can get context window if is_finished: @@ -302,17 +314,22 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window + end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) + # extract the audio x_chunk = 
self.cached_feat[:, cur:end, :].numpy() x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) self.result_transcripts = [trans_best] + # update feat cache self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] - # return trans_best[0] + + # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: logger.info( @@ -328,7 +345,16 @@ class PaddleASRConnectionHanddler: @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): - logger.info("start to decoce one chunk with deepspeech2 model") + """forward one chunk frames + + Args: + x_chunk (np.ndarray): (B,T,D), audio frames. + x_chunk_lens ([type]): (B,), audio frame lens + + Returns: + logprob: poster probability. + """ + logger.info("start to decoce one chunk for deepspeech2") input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) @@ -365,28 +391,46 @@ class PaddleASRConnectionHanddler: self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one best result: {trans_best[0]}") + logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") return trans_best[0] @paddle.no_grad() def advance_decoding(self, is_finished=False): - logger.info("start to decode with advanced_decoding method") + if "deepspeech" in self.model_type: + return + + # reset endpiont state + self.endpoint_state = False + + logger.info( + "Conformer/Transformer: start to decode with advanced_decoding method" + ) cfg = self.ctc_decode_config + + # cur chunk size, in decoding frame unit, e.g. 16 decoding_chunk_size = cfg.decoding_chunk_size + # using num of history chunks, e.g -1 num_decoding_left_chunks = cfg.num_decoding_left_chunks - assert decoding_chunk_size > 0 + + # e.g. 
4 subsampling = self.model.encoder.embed.subsampling_rate + # e.g. 7 context = self.model.encoder.embed.right_context + 1 - stride = subsampling * decoding_chunk_size - cached_feature_num = context - subsampling # processed chunk feature cached for next chunk - # decoding window for model + # processed chunk feature cached for next chunk, e.g. 3 + cached_feature_num = context - subsampling + + # decoding window, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride, in audio frame unit + stride = subsampling * decoding_chunk_size + if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return + # (B=1,T,D) num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" @@ -407,8 +451,6 @@ class PaddleASRConnectionHanddler: return None, None logger.info("start to do model forward") - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - outputs = [] # num_frames - context + 1 ensure that current frame can get context window if is_finished: @@ -418,13 +460,20 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window + # hist of chunks, in deocding frame unit + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + # record the end for removing the processed feat + outputs = [] end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) + # global chunk_num self.chunk_num += 1 + # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] + # forward chunk (y, self.subsampling_cache, self.elayers_output_cache, self.conformer_cnn_cache) = self.model.encoder.forward_chunk( chunk_xs, self.offset, required_cache_size, @@ -432,7 +481,7 @@ class PaddleASRConnectionHanddler: self.conformer_cnn_cache) outputs.append(y) - # update the offset + # update the global 
offset, in decoding frame unit self.offset += y.shape[1] ys = paddle.cat(outputs, 1) @@ -440,72 +489,116 @@ class PaddleASRConnectionHanddler: self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) + ## decoding + # advance decoding self.searcher.search(ctc_probs, self.cached_feat.place) - + # get one best hyps self.hyps = self.searcher.get_one_best_hyps() - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num - self.cached_feat = self.cached_feat[0, end - - cached_feature_num:, :].unsqueeze(0) + # endpoint + if not is_finished: + + def contain_nonsilence(): + return len(self.hyps) > 0 and len(self.hyps[0]) > 0 + + decoding_something = contain_nonsilence() + if self.endpointer.endpoint_detected(ctc_probs.numpy(), + decoding_something): + self.endpoint_state = True + logger.info(f"Endpoint is detected at {self.num_frames} frame.") + + # advance cache of feat + assert self.cached_feat.shape[0] == 1 #(B=1,T,D) + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] assert len( self.cached_feat.shape ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - logger.info( - f"This connection handler encoder out shape: {self.encoder_out.shape}" - ) - def update_result(self): + """Conformer/Transformer hyps to result. + """ logger.info("update the final result") hyps = self.hyps + + # output results and tokenids self.result_transcripts = [ self.text_feature.defeaturize(hyp) for hyp in hyps ] self.result_tokenids = [hyp for hyp in hyps] def get_result(self): + """return partial/ending asr result. + + Returns: + str: one best result of partial/ending. 
+ """ if len(self.result_transcripts) > 0: return self.result_transcripts[0] else: return '' def get_word_time_stamp(self): + """return token timestamp result. + + Returns: + list: List of ('w':token, 'bg':time, 'ed':time) + """ return self.word_time_stamp @paddle.no_grad() def rescoring(self): - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + """Second-Pass Decoding, + only for conformer and transformer model. + """ + if "deepspeech2" in self.model_type: + logger.info("deepspeech2 not support rescoring decoding.") return - logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: + logger.info( + f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring" + ) return + logger.info("rescoring the final result") + + # last decoding for last audio self.searcher.finalize_search() + # update beam search results self.update_result() beam_size = self.ctc_decode_config.beam_size hyps = self.searcher.get_hyps() if hyps is None or len(hyps) == 0: + logger.info("No Hyps!") return + # rescore by decoder post probability + # assert len(hyps) == beam_size + # list of Tensor hyp_list = [] for hyp in hyps: hyp_content = hyp[0] # Prevent the hyp is empty if len(hyp_content) == 0: hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( hyp_content, place=self.device, dtype=paddle.long) hyp_list.append(hyp_content) - hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + + hyps_pad = pad_sequence( + hyp_list, batch_first=True, padding_value=self.model.ignore_id) hyps_lens = paddle.to_tensor( [len(hyp[0]) for hyp in hyps], place=self.device, dtype=paddle.long) # (beam_size,) @@ -531,10 +624,12 @@ class PaddleASRConnectionHanddler: score = 0.0 for j, w in enumerate(hyp[0]): score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. 
score += decoder_out[i][len(hyp[0])][self.model.eos] # add ctc score (which in ln domain) score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: best_score = score best_index = i @@ -542,43 +637,56 @@ class PaddleASRConnectionHanddler: # update the one best result # hyps stored the beam results and each fields is: - logger.info(f"best index: {best_index}") + logger.info(f"best hyp index: {best_index}") # logger.info(f'best result: {hyps[best_index]}') # the field of the hyps is: + ## asr results # hyps[0][0]: the sentence word-id in the vocab with a tuple # hyps[0][1]: the sentence decoding probability with all paths + ## timestamp # hyps[0][2]: viterbi_blank ending probability - # hyps[0][3]: viterbi_non_blank probability + # hyps[0][3]: viterbi_non_blank dending probability # hyps[0][4]: current_token_prob, - # hyps[0][5]: times_viterbi_blank, - # hyps[0][6]: times_titerbi_non_blank + # hyps[0][5]: times_viterbi_blank ending timestamp, + # hyps[0][6]: times_titerbi_non_blank encding timestamp. 
self.hyps = [hyps[best_index][0]] + logger.info(f"best hyp ids: {self.hyps}") # update the hyps time stamp self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[ best_index][3] else hyps[best_index][6] logger.info(f"time stamp: {self.time_stamp}") + # update one best result self.update_result() # update each word start and end time stamp - frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate - logger.info(f"frame shift ms: {frame_shift_in_ms}") + # decoding frame to audio frame + decode_frame_shift = self.model.encoder.embed.subsampling_rate + decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift / + self.sample_rate) + logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}") + + global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0 + logger.info(f"global offset: {global_offset_in_sec} sec.") + word_time_stamp = [] for idx, _ in enumerate(self.time_stamp): start = (self.time_stamp[idx - 1] + self.time_stamp[idx] ) / 2.0 if idx > 0 else 0 - start = start * frame_shift_in_ms + start = start * decode_frame_shift_in_sec end = (self.time_stamp[idx] + self.time_stamp[idx + 1] ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset - end = end * frame_shift_in_ms + + end = end * decode_frame_shift_in_sec word_time_stamp.append({ "w": self.result_transcripts[0][idx], - "bg": start, - "ed": end + "bg": global_offset_in_sec + start, + "ed": global_offset_in_sec + end }) - # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}") + # logger.info(f"{word_time_stamp[-1]}") + self.word_time_stamp = word_time_stamp logger.info(f"word time stamp: {self.word_time_stamp}") @@ -586,7 +694,8 @@ class PaddleASRConnectionHanddler: class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='asr', model_format='dynamic', inference_mode='online') 
def _init_from_path(self, model_type: str=None, @@ -596,6 +705,7 @@ class ASRServerExecutor(ASRExecutor): sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, am_predictor_conf: dict=None): """ Init model and other resources from a specific path. @@ -608,21 +718,21 @@ class ASRServerExecutor(ASRExecutor): self.model_type = model_type self.sample_rate = sample_rate + logger.info(f"model_type: {self.model_type}") + sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str - if cfg_path is None or am_model is None or am_params is None: - logger.info(f"Load the pretrained model, tag = {tag}") - res_path = self._get_pretrained_path(tag) # wenetspeech_zh - self.res_path = res_path + self.task_resource.set_task_model(model_tag=tag) + if cfg_path is None or am_model is None or am_params is None: + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) - self.am_model = os.path.join(res_path, - self.pretrained_models[tag]['model']) - self.am_params = os.path.join(res_path, - self.pretrained_models[tag]['params']) - logger.info(res_path) + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -630,44 +740,61 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) + logger.info("Load the pretrained model:") + logger.info(f" tag = {tag}") + logger.info(f" res_path: {self.res_path}") + logger.info(f" cfg path: {self.cfg_path}") + logger.info(f" am_model 
path: {self.am_model}") + logger.info(f" am_params path: {self.am_params}") #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) - with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator - self.vocab = self.config.vocab_filepath + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + logger.info(f"spm model path: {self.config.spm_model_prefix}") + + self.vocab = self.config.vocab_filepath + + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + + if "deepspeech2" in model_type: + with UpdateConfig(self.config): + # download lm self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) - - lm_url = self.pretrained_models[tag]['lm_url'] - lm_md5 = self.pretrained_models[tag]['lm_md5'] - logger.info(f"Start to load language model {lm_url}") - self.download_lm( - lm_url, - os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type: + + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + elif "conformer" in model_type 
or "transformer" in model_type: + with UpdateConfig(self.config): logger.info("start to create the stream conformer asr engine") - if self.config.spm_model_prefix: - self.config.spm_model_prefix = os.path.join( - self.res_path, self.config.spm_model_prefix) - self.vocab = self.config.vocab_filepath - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, - vocab=self.config.vocab_filepath, - spm_model_prefix=self.config.spm_model_prefix) # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method + # update num_decoding_left_chunks + if num_decoding_left_chunks: + assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" + self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks # we only support ctc_prefix_beam_search and attention_rescoring dedoding method # Generally we set the decoding_method to attention_rescoring @@ -677,337 +804,29 @@ class ASRServerExecutor(ASRExecutor): logger.info( "we set the decoding_method to attention_rescoring") self.config.decode.decoding_method = "attention_rescoring" + assert self.config.decode.decoding_method in [ "ctc_prefix_beam_search", "attention_rescoring" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" - else: - raise Exception("wrong type") - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - # decoder - logger.info("ASR engine start to create the ctc decoder instance") - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - 
blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - logger.info("ASR engine start to init the ctc decoder") - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - elif "conformer" in model_type or "transformer" in model_type: + # load model model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") - model_class = dynamic_import(model_name, self.model_alias) - model_conf = self.config - model = model_class.from_config(model_conf) + model_class = self.task_resource.get_model_class(model_name) + model = model_class.from_config(self.config) self.model = model + self.model.set_state_dict(paddle.load(self.am_model)) self.model.eval() - - # load model - model_dict = paddle.load(self.am_model) - self.model.set_state_dict(model_dict) - logger.info("create the transformer like model success") - - # update the ctc decoding - self.searcher = CTCPrefixBeamSearch(self.config.decode) - self.transformer_decode_reset() - - return True - - def reset_decoder_and_chunk(self): - """reset decoder and chunk state for an new audio - """ - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, 
self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - elif "conformer" in self.model_type or "transformer" in self.model_type: - self.transformer_decode_reset() - - def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): - """decode one chunk - - Args: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - model_type (str): online model type - - Returns: - str: one best result - """ - logger.info("start to decoce chunk by chunk") - if "deepspeech2online" in model_type: - input_names = self.am_predictor.get_input_names() - audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle( - input_names[1]) - h_box_handle = self.am_predictor.get_input_handle(input_names[2]) - c_box_handle = self.am_predictor.get_input_handle(input_names[3]) - - audio_handle.reshape(x_chunk.shape) - audio_handle.copy_from_cpu(x_chunk) - - audio_len_handle.reshape(x_chunk_lens.shape) - audio_len_handle.copy_from_cpu(x_chunk_lens) - - h_box_handle.reshape(self.chunk_state_h_box.shape) - h_box_handle.copy_from_cpu(self.chunk_state_h_box) - - c_box_handle.reshape(self.chunk_state_c_box.shape) - c_box_handle.copy_from_cpu(self.chunk_state_c_box) - - output_names = self.am_predictor.get_output_names() - output_handle = self.am_predictor.get_output_handle(output_names[0]) - output_lens_handle = self.am_predictor.get_output_handle( - output_names[1]) - output_state_h_handle = self.am_predictor.get_output_handle( - output_names[2]) - output_state_c_handle = self.am_predictor.get_output_handle( - output_names[3]) - - self.am_predictor.run() - - output_chunk_probs = output_handle.copy_to_cpu() - output_chunk_lens = output_lens_handle.copy_to_cpu() - self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() - self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() - - 
self.decoder.next(output_chunk_probs, output_chunk_lens) - trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one best result: {trans_best[0]}") - return trans_best[0] - - elif "conformer" in model_type or "transformer" in model_type: - try: - logger.info( - f"we will use the transformer like model : {self.model_type}" - ) - self.advanced_decoding(x_chunk, x_chunk_lens) - self.update_result() - - return self.result_transcripts[0] - except Exception as e: - logger.exception(e) else: - raise Exception("invalid model name") - - def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): - logger.info("start to decode with advanced_decoding method") - encoder_out, encoder_mask = self.encoder_forward(xs) - ctc_probs = self.model.ctc.log_softmax( - encoder_out) # (1, maxlen, vocab_size) - ctc_probs = ctc_probs.squeeze(0) - self.searcher.search(ctc_probs, xs.place) - # update the one best result - self.hyps = self.searcher.get_one_best_hyps() + raise Exception(f"not support: {model_type}") - # now we supprot ctc_prefix_beam_search and attention_rescoring - if "attention_rescoring" in self.config.decode.decoding_method: - self.rescoring(encoder_out, xs.place) - - def encoder_forward(self, xs): - logger.info("get the model out from the feat") - cfg = self.config.decode - decoding_chunk_size = cfg.decoding_chunk_size - num_decoding_left_chunks = cfg.num_decoding_left_chunks - - assert decoding_chunk_size > 0 - subsampling = self.model.encoder.embed.subsampling_rate - context = self.model.encoder.embed.right_context + 1 - stride = subsampling * decoding_chunk_size - - # decoding window for model - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.shape[1] - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - - logger.info("start to do model forward") - outputs = [] - - # num_frames - context + 1 ensure that current frame can get context window - for cur in range(0, num_frames - context + 1, stride): - 
end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) = self.model.encoder.forward_chunk( - chunk_xs, self.offset, required_cache_size, - self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) - outputs.append(y) - self.offset += y.shape[1] - - ys = paddle.cat(outputs, 1) - masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - masks = masks.unsqueeze(1) - return ys, masks - - def rescoring(self, encoder_out, device): - logger.info("start to rescoring the hyps") - beam_size = self.config.decode.beam_size - hyps = self.searcher.get_hyps() - assert len(hyps) == beam_size - - hyp_list = [] - for hyp in hyps: - hyp_content = hyp[0] - # Prevent the hyp is empty - if len(hyp_content) == 0: - hyp_content = (self.model.ctc.blank_id, ) - hyp_content = paddle.to_tensor( - hyp_content, place=device, dtype=paddle.long) - hyp_list.append(hyp_content) - hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) - hyps_lens = paddle.to_tensor( - [len(hyp[0]) for hyp in hyps], place=device, - dtype=paddle.long) # (beam_size,) - hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, - self.model.ignore_id) - hyps_lens = hyps_lens + 1 # Add at begining - - encoder_out = encoder_out.repeat(beam_size, 1, 1) - encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - decoder_out, _ = self.model.decoder( - encoder_out, encoder_mask, hyps_pad, - hyps_lens) # (beam_size, max_hyps_len, vocab_size) - # ctc score in ln domain - decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - decoder_out = decoder_out.numpy() - - # Only use decoder score for rescoring - best_score = -float('inf') - best_index = 0 - # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size - for i, hyp in enumerate(hyps): - score = 0.0 - for j, w in enumerate(hyp[0]): - score += decoder_out[i][j][w] - # last decoder 
output token is `eos`, for laste decoder input token. - score += decoder_out[i][len(hyp[0])][self.model.eos] - # add ctc score (which in ln domain) - score += hyp[1] * self.config.decode.ctc_weight - if score > best_score: - best_score = score - best_index = i - - # update the one best result - self.hyps = [hyps[best_index][0]] - return hyps[best_index][0] - - def transformer_decode_reset(self): - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.offset = 0 - # decoding reset - self.searcher.reset() - - def update_result(self): - logger.info("update the final result") - hyps = self.hyps - self.result_transcripts = [ - self.text_feature.defeaturize(hyp) for hyp in hyps - ] - self.result_tokenids = [hyp for hyp in hyps] - - def extract_feat(self, samples, sample_rate): - """extract feat - - Args: - samples (numpy.array): numpy.float32 - sample_rate (int): sample rate - - Returns: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - """ - - if "deepspeech2online" in self.model_type: - # pcm16 -> pcm 32 - samples = pcm2float(samples) - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature( - spectrum) - - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) - - return x_chunk, x_chunk_lens - elif "conformer_online" 
in self.model_type: - - if sample_rate != self.sample_rate: - logger.info(f"audio sample rate {sample_rate} is not match," - "the model sample_rate is {self.sample_rate}") - logger.info(f"ASR Engine use the {self.model_type} to process") - logger.info("Create the preprocess instance") - preprocess_conf = self.config.preprocess_config - preprocess_args = {"train": False} - preprocessing = Transformation(preprocess_conf) - - logger.info("Read the audio file") - logger.info(f"audio shape: {samples.shape}") - # fbank - x_chunk = preprocessing(samples, **preprocess_args) - x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) - x_chunk = paddle.to_tensor( - x_chunk, dtype="float32").unsqueeze(axis=0) - logger.info( - f"process the audio feature success, feat shape: {x_chunk.shape}" - ) - return x_chunk, x_chunk_lens + logger.info(f"create the {model_type} model success") + return True class ASREngine(BaseEngine): - """ASR server engine + """ASR server resource Args: metaclass: Defaults to Singleton. 
@@ -1015,7 +834,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine instance") + logger.info("create the online asr engine resource instance") def init(self, config: dict) -> bool: """init engine resource @@ -1026,16 +845,11 @@ class ASREngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = "" - self.executor = ASRServerExecutor() self.config = config + self.executor = ASRServerExecutor() + try: - if self.config.get("device", None): - self.device = self.config.device - else: - self.device = paddle.get_device() - logger.info(f"paddlespeech_server set the device: {self.device}") + self.device = self.config.get("device", paddle.get_device()) paddle.set_device(self.device) except BaseException as e: logger.error( @@ -1045,6 +859,8 @@ class ASREngine(BaseEngine): "If all GPU or XPU is used, you can set the server to 'cpu'") sys.exit(-1) + logger.info(f"paddlespeech_server set the device: {self.device}") + if not self.executor._init_from_path( model_type=self.config.model_type, am_model=self.config.am_model, @@ -1053,6 +869,7 @@ class ASREngine(BaseEngine): sample_rate=self.config.sample_rate, cfg_path=self.config.cfg_path, decode_method=self.config.decode_method, + num_decoding_left_chunks=self.config.num_decoding_left_chunks, am_predictor_conf=self.config.am_predictor_conf): logger.error( "Init the ASR server occurs error, please check the server configuration yaml" @@ -1062,42 +879,19 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, - samples, - sample_rate, - model_type="deepspeech2online_aishell-zh-16k"): - """preprocess - - Args: - samples (numpy.array): numpy.float32 - sample_rate (int): sample rate + def new_handler(self): + """New handler from model. 
Returns: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] + PaddleASRConnectionHanddler: asr handler instance """ - # if "deepspeech" in model_type: - x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) - return x_chunk, x_chunk_lens + return PaddleASRConnectionHanddler(self) - def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1): - """run online engine + def preprocess(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") - Args: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - decoder_chunk_size(int) - """ - self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, - self.config.model_type) + def run(self, *args, **kwargs): + raise NotImplementedError("Online not using this.") def postprocess(self): - """postprocess - """ - return self.output - - def reset(self): - """reset engine decoder and inference state - """ - self.executor.reset_decoder_and_chunk() - self.output = "" + raise NotImplementedError("Online not using this.") diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py new file mode 100644 index 000000000..2dba36417 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass + +import numpy as np + +from paddlespeech.cli.log import logger + + +@dataclass +class OnlineCTCEndpointRule: + must_contain_nonsilence: bool = True + min_trailing_silence: int = 1000 + min_utterance_length: int = 0 + + +@dataclass +class OnlineCTCEndpoingOpt: + frame_shift_in_ms: int = 10 + + blank: int = 0 # blank id, that we consider as silence for purposes of endpointing. + blank_threshold: float = 0.8 # above blank threshold is silence + + # We support three rules. We terminate decoding if ANY of these rules + # evaluates to "true". If you want to add more rules, do it by changing this + # code. If you want to disable a rule, you can set the silence-timeout for + # that rule to a very large number. + + # rule1 times out after 5 seconds of silence, even if we decoded nothing. + rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0) + # rule2 times out after 1.0 seconds of silence after decoding something, + # even if we did not reach a final-state at all. + rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0) + # rule3 times out after the utterance is 20 seconds long, regardless of + # anything else. 
+ rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000) + + +class OnlineCTCEndpoint: + """ + [END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf) + """ + + def __init__(self, opts: OnlineCTCEndpoingOpt): + self.opts = opts + logger.info(f"Endpont Opts: {opts}") + self.frame_shift_in_ms = opts.frame_shift_in_ms + + self.num_frames_decoded = 0 + self.trailing_silence_frames = 0 + + self.reset() + + def reset(self): + self.num_frames_decoded = 0 + self.trailing_silence_frames = 0 + + def rule_activated(self, + rule: OnlineCTCEndpointRule, + rule_name: str, + decoding_something: bool, + trailine_silence: int, + utterance_length: int) -> bool: + ans = ( + decoding_something or (not rule.must_contain_nonsilence) + ) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length + if (ans): + logger.info(f"Endpoint Rule: {rule_name} activated: {rule}") + return ans + + def endpoint_detected(self, + ctc_log_probs: np.ndarray, + decoding_something: bool) -> bool: + """detect endpoint. + + Args: + ctc_log_probs (np.ndarray): (T, D) + decoding_something (bool): contain nonsilince. + + Returns: + bool: whether endpoint detected. 
+ """ + for logprob in ctc_log_probs: + blank_prob = np.exp(logprob[self.opts.blank]) + + self.num_frames_decoded += 1 + if blank_prob > self.opts.blank_threshold: + self.trailing_silence_frames += 1 + else: + self.trailing_silence_frames = 0 + + assert self.num_frames_decoded >= self.trailing_silence_frames + assert self.frame_shift_in_ms > 0 + + utterance_length = self.num_frames_decoded * self.frame_shift_in_ms + trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms + + if self.rule_activated(self.opts.rule1, 'rule1', decoding_something, + trailing_silence, utterance_length): + return True + if self.rule_activated(self.opts.rule2, 'rule2', decoding_something, + trailing_silence, utterance_length): + return True + if self.rule_activated(self.opts.rule3, 'rule3', decoding_something, + trailing_silence, utterance_length): + return True + return False diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index 4c9ac3acb..46f310c80 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -30,8 +30,29 @@ class CTCPrefixBeamSearch: config (yacs.config.CfgNode): the ctc prefix beam search configuration """ self.config = config + + # beam size + self.first_beam_size = self.config.beam_size + # TODO(support second beam size) + self.second_beam_size = int(self.first_beam_size * 1.0) + logger.info( + f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}" + ) + + # state + self.cur_hyps = None + self.hyps = None + self.abs_time_step = 0 + self.reset() + def reset(self): + """Rest the search cache value + """ + self.cur_hyps = None + self.hyps = None + self.abs_time_step = 0 + @paddle.no_grad() def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature @@ -47,12 +68,17 @@ class CTCPrefixBeamSearch: """ # decode logger.info("start to ctc prefix search") - + 
assert len(ctc_probs.shape) == 2 batch_size = 1 - beam_size = self.config.beam_size - maxlen = ctc_probs.shape[0] - assert len(ctc_probs.shape) == 2 + vocab_size = ctc_probs.shape[1] + first_beam_size = min(self.first_beam_size, vocab_size) + second_beam_size = min(self.second_beam_size, vocab_size) + logger.info( + f"effect first and second beam size: {self.first_beam_size}, {self.second_beam_size}" + ) + + maxlen = ctc_probs.shape[0] # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) # 0. blank_ending_score, @@ -75,7 +101,8 @@ class CTCPrefixBeamSearch: # 2.1 First beam prune: select topk best # do token passing process - top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + top_k_logp, top_k_index = logp.topk( + first_beam_size) # (first_beam_size,) for s in top_k_index: s = s.item() ps = logp[s].item() @@ -148,7 +175,7 @@ class CTCPrefixBeamSearch: next_hyps.items(), key=lambda x: log_add([x[1][0], x[1][1]]), reverse=True) - self.cur_hyps = next_hyps[:beam_size] + self.cur_hyps = next_hyps[:second_beam_size] # 2.3 update the absolute time step self.abs_time_step += 1 @@ -163,7 +190,7 @@ class CTCPrefixBeamSearch: """Return the one best result Returns: - list: the one best result + list: the one best result, List[str] """ return [self.hyps[0][0]] @@ -171,17 +198,10 @@ class CTCPrefixBeamSearch: """Return the search hyps Returns: - list: return the search hyps + list: return the search hyps, List[Tuple[str, float, ...]] """ return self.hyps - def reset(self): - """Rest the search cache value - """ - self.cur_hyps = None - self.hyps = None - self.abs_time_step = 0 - def finalize_search(self): """do nothing in ctc_prefix_beam_search """ diff --git a/paddlespeech/server/engine/asr/online/pretrained_models.py b/paddlespeech/server/engine/asr/online/pretrained_models.py deleted file mode 100644 index ff3778657..000000000 --- a/paddlespeech/server/engine/asr/online/pretrained_models.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2022 
PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - "deepspeech2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz', - 'md5': - '98b87b171b7240b7cae6e07d8d0bc9be', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'model': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "conformer_online_multicn-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', - 'md5': - '0ac93d390552336f2a906aec9e33c5fa', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/multi_cn', - 'model': - 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', - 'params': - 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "conformer_online_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz', - 'md5': - 
'b8c02632b04da34aca88459835be54a6', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_10', - 'model': - 'exp/chunk_conformer/checkpoints/avg_10.pdparams', - 'params': - 'exp/chunk_conformer/checkpoints/avg_10.pdparams', - 'lm_url': - '', - 'lm_md5': - '', - }, -} diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py index e275f1088..1a3b4620a 100644 --- a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py @@ -19,10 +19,10 @@ from typing import Optional import paddle from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig @@ -30,13 +30,14 @@ from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model -__all__ = ['ASREngine'] +__all__ = ['ASREngine', 'PaddleASRConnectionHandler'] class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='asr', model_format='static') def _init_from_path(self, model_type: str='wenetspeech', @@ -50,20 +51,21 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. 
""" - + self.max_len = 50 sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + self.max_len = 50 + self.task_resource.set_task_model(model_tag=tag) if cfg_path is None or am_model is None or am_params is None: - res_path = self._get_pretrained_path(tag) # wenetspeech_zh - self.res_path = res_path + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) - self.am_model = os.path.join(res_path, - self.pretrained_models[tag]['model']) - self.am_params = os.path.join(res_path, - self.pretrained_models[tag]['params']) - logger.info(res_path) + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) + logger.info(self.res_path) logger.info(self.cfg_path) logger.info(self.am_model) logger.info(self.am_params) @@ -79,22 +81,25 @@ class ASRServerExecutor(ASRExecutor): self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - from paddlespeech.s2t.io.collator import SpeechCollator + if "deepspeech2" in model_type: self.vocab = self.config.vocab_filepath + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.vocab, + spm_model_prefix=self.config.spm_model_prefix) self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - self.collate_fn_test = SpeechCollator.from_config(self.config) - self.text_feature = TextFeaturizer( - unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = self.pretrained_models[tag]['lm_url'] - lm_md5 = 
self.pretrained_models[tag]['lm_md5'] + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: raise Exception("wrong type") else: raise Exception("wrong type") @@ -124,7 +129,7 @@ class ASRServerExecutor(ASRExecutor): cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + if "deepspeech2" in model_type: decode_batch_size = audio.shape[0] # init once self.decoder.init_decoder( @@ -172,10 +177,23 @@ class ASREngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = ASRServerExecutor() self.config = config + self.engine_type = "inference" + + try: + if self.config.am_predictor_conf.device is not None: + self.device = self.config.am_predictor_conf.device + else: + self.device = paddle.get_device() + + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error(e) + return False self.executor._init_from_path( model_type=self.config.model_type, @@ -190,22 +208,41 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True + +class PaddleASRConnectionHandler(ASRServerExecutor): + def __init__(self, asr_engine): + """The PaddleSpeech ASR Server Connection Handler + This connection process every asr server request + Args: + asr_engine (ASREngine): The ASR engine + """ + super().__init__() + self.input = None + self.output = None + self.asr_engine = asr_engine + self.executor = self.asr_engine.executor + self.config = self.executor.config 
+ self.max_len = self.executor.max_len + self.decoder = self.executor.decoder + self.am_predictor = self.executor.am_predictor + self.text_feature = self.executor.text_feature + def run(self, audio_data): - """engine run + """engine run Args: audio_data (bytes): base64.b64decode """ - if self.executor._check( - io.BytesIO(audio_data), self.config.sample_rate, - self.config.force_yes): + if self._check( + io.BytesIO(audio_data), self.asr_engine.config.sample_rate, + self.asr_engine.config.force_yes): logger.info("start running asr engine") - self.executor.preprocess(self.config.model_type, - io.BytesIO(audio_data)) + self.preprocess(self.asr_engine.config.model_type, + io.BytesIO(audio_data)) st = time.time() - self.executor.infer(self.config.model_type) + self.infer(self.asr_engine.config.model_type) infer_time = time.time() - st - self.output = self.executor.postprocess() # Retrieve result of asr. + self.output = self.postprocess() # Retrieve result of asr. logger.info("end inferring asr engine") else: logger.info("file check failed!") @@ -213,8 +250,3 @@ class ASREngine(BaseEngine): logger.info("inference time: {}".format(infer_time)) logger.info("asr engine type: paddle inference") - - def postprocess(self): - """postprocess - """ - return self.output diff --git a/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py b/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py deleted file mode 100644 index c4c23e38c..000000000 --- a/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - "deepspeech2offline_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', - 'md5': - '932c3593d62fe5c741b59b31318aa314', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'model': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel', - 'params': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, -} diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py index d60a5feae..f9cc3a665 100644 --- a/paddlespeech/server/engine/asr/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import io +import sys import time import paddle @@ -20,7 +21,7 @@ from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.server.engine.base_engine import BaseEngine -__all__ = ['ASREngine'] +__all__ = ['ASREngine', 'PaddleASRConnectionHandler'] class ASRServerExecutor(ASRExecutor): @@ -48,20 +49,23 @@ class ASREngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = ASRServerExecutor() self.config = config + self.engine_type = "python" + try: - if self.config.device: + if self.config.device is not None: self.device = self.config.device else: self.device = paddle.get_device() + paddle.set_device(self.device) - except BaseException: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) + logger.error(e) + return False self.executor._init_from_path( self.config.model, self.config.lang, self.config.sample_rate, @@ -72,6 +76,24 @@ class ASREngine(BaseEngine): (self.device)) return True + +class PaddleASRConnectionHandler(ASRServerExecutor): + def __init__(self, asr_engine): + """The PaddleSpeech ASR Server Connection Handler + This connection process every asr server request + Args: + asr_engine (ASREngine): The ASR engine + """ + super().__init__() + self.input = None + self.output = None + self.asr_engine = asr_engine + self.executor = self.asr_engine.executor + self.max_len = self.executor.max_len + self.text_feature = self.executor.text_feature + self.model = self.executor.model + self.config = self.executor.config + def run(self, audio_data): """engine run @@ -79,17 +101,16 @@ class ASREngine(BaseEngine): audio_data (bytes): base64.b64decode """ try: - if self.executor._check( - io.BytesIO(audio_data), self.config.sample_rate, - self.config.force_yes): + if self._check( + io.BytesIO(audio_data), self.asr_engine.config.sample_rate, + 
self.asr_engine.config.force_yes): logger.info("start run asr engine") - self.executor.preprocess(self.config.model, - io.BytesIO(audio_data)) + self.preprocess(self.asr_engine.config.model, + io.BytesIO(audio_data)) st = time.time() - self.executor.infer(self.config.model) + self.infer(self.asr_engine.config.model) infer_time = time.time() - st - self.output = self.executor.postprocess( - ) # Retrieve result of asr. + self.output = self.postprocess() # Retrieve result of asr. else: logger.info("file check failed!") self.output = None @@ -98,8 +119,4 @@ class ASREngine(BaseEngine): logger.info("asr engine type: python") except Exception as e: logger.info(e) - - def postprocess(self): - """postprocess - """ - return self.output + sys.exit(-1) diff --git a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py index 0906c2412..389d56055 100644 --- a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py +++ b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py @@ -14,30 +14,32 @@ import io import os import time +from collections import OrderedDict from typing import Optional import numpy as np import paddle import yaml -from .pretrained_models import pretrained_models from paddlespeech.cli.cls.infer import CLSExecutor from paddlespeech.cli.log import logger +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model -__all__ = ['CLSEngine'] +__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler'] class CLSServerExecutor(CLSExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='cls', model_format='static') def _init_from_path( self, - model_type: str='panns_cnn14', + model_type: str='panns_cnn14_audioset', 
cfg_path: Optional[os.PathLike]=None, model_path: Optional[os.PathLike]=None, params_path: Optional[os.PathLike]=None, @@ -49,15 +51,16 @@ class CLSServerExecutor(CLSExecutor): if cfg_path is None or model_path is None or params_path is None or label_file is None: tag = model_type + '-' + '32k' - self.res_path = self._get_pretrained_path(tag) + self.task_resource.set_task_model(model_tag=tag) + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - self.res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) self.model_path = os.path.join( - self.res_path, self.pretrained_models[tag]['model_path']) + self.res_path, self.task_resource.res_dict['model_path']) self.params_path = os.path.join( - self.res_path, self.pretrained_models[tag]['params_path']) + self.res_path, self.task_resource.res_dict['params_path']) self.label_file = os.path.join( - self.res_path, self.pretrained_models[tag]['label_file']) + self.res_path, self.task_resource.res_dict['label_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.model_path = os.path.abspath(model_path) @@ -119,14 +122,55 @@ class CLSEngine(BaseEngine): """ self.executor = CLSServerExecutor() self.config = config - self.executor._init_from_path( - self.config.model_type, self.config.cfg_path, - self.config.model_path, self.config.params_path, - self.config.label_file, self.config.predictor_conf) + self.engine_type = "inference" + + try: + if self.config.predictor_conf.device is not None: + self.device = self.config.predictor_conf.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error(e) + return False + + try: + self.executor._init_from_path( + self.config.model_type, self.config.cfg_path, + self.config.model_path, self.config.params_path, + 
self.config.label_file, self.config.predictor_conf) + + except Exception as e: + logger.error("Initialize CLS server engine Failed.") + logger.error(e) + return False logger.info("Initialize CLS server engine successfully.") return True + +class PaddleCLSConnectionHandler(CLSServerExecutor): + def __init__(self, cls_engine): + """The PaddleSpeech CLS Server Connection Handler + This connection process every cls server request + Args: + cls_engine (CLSEngine): The CLS engine + """ + super().__init__() + logger.info( + "Create PaddleCLSConnectionHandler to process the cls request") + + self._inputs = OrderedDict() + self._outputs = OrderedDict() + self.cls_engine = cls_engine + self.executor = self.cls_engine.executor + self._conf = self.executor._conf + self._label_list = self.executor._label_list + self.predictor = self.executor.predictor + def run(self, audio_data): """engine run @@ -134,9 +178,9 @@ class CLSEngine(BaseEngine): audio_data (bytes): base64.b64decode """ - self.executor.preprocess(io.BytesIO(audio_data)) + self.preprocess(io.BytesIO(audio_data)) st = time.time() - self.executor.infer() + self.infer() infer_time = time.time() - st logger.info("inference time: {}".format(infer_time)) @@ -145,15 +189,15 @@ class CLSEngine(BaseEngine): def postprocess(self, topk: int): """postprocess """ - assert topk <= len(self.executor._label_list - ), 'Value of topk is larger than number of labels.' + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' 
- result = np.squeeze(self.executor._outputs['logits'], axis=0) + result = np.squeeze(self._outputs['logits'], axis=0) topk_idx = (-result).argsort()[:topk] topk_results = [] for idx in topk_idx: res = {} - label, score = self.executor._label_list[idx], result[idx] + label, score = self._label_list[idx], result[idx] res['class_name'] = label res['prob'] = score topk_results.append(res) diff --git a/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py b/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py deleted file mode 100644 index e49148746..000000000 --- a/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -pretrained_models = { - "panns_cnn6-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz', - 'md5': - 'da087c31046d23281d8ec5188c1967da', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, - "panns_cnn10-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz', - 'md5': - '5460cc6eafbfaf0f261cc75b90284ae1', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, - "panns_cnn14-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz', - 'md5': - 'ccc80b194821274da79466862b2ab00f', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, -} diff --git a/paddlespeech/server/engine/cls/python/cls_engine.py b/paddlespeech/server/engine/cls/python/cls_engine.py index 1a975b0a0..f8d8f20ef 100644 --- a/paddlespeech/server/engine/cls/python/cls_engine.py +++ b/paddlespeech/server/engine/cls/python/cls_engine.py @@ -13,7 +13,7 @@ # limitations under the License. import io import time -from typing import List +from collections import OrderedDict import paddle @@ -21,7 +21,7 @@ from paddlespeech.cli.cls.infer import CLSExecutor from paddlespeech.cli.log import logger from paddlespeech.server.engine.base_engine import BaseEngine -__all__ = ['CLSEngine'] +__all__ = ['CLSEngine', 'PaddleCLSConnectionHandler'] class CLSServerExecutor(CLSExecutor): @@ -29,21 +29,6 @@ class CLSServerExecutor(CLSExecutor): super().__init__() pass - def get_topk_results(self, topk: int) -> List: - assert topk <= len( - self._label_list), 'Value of topk is larger than number of labels.' 
- - result = self._outputs['logits'].squeeze(0).numpy() - topk_idx = (-result).argsort()[:topk] - res = {} - topk_results = [] - for idx in topk_idx: - label, score = self._label_list[idx], result[idx] - res['class'] = label - res['prob'] = score - topk_results.append(res) - return topk_results - class CLSEngine(BaseEngine): """CLS server engine @@ -64,42 +49,65 @@ class CLSEngine(BaseEngine): Returns: bool: init failed or success """ - self.input = None - self.output = None self.executor = CLSServerExecutor() self.config = config + self.engine_type = "python" + try: - if self.config.device: + if self.config.device is not None: self.device = self.config.device else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) + logger.error(e) + return False try: self.executor._init_from_path( self.config.model, self.config.cfg_path, self.config.ckpt_path, self.config.label_file) - except BaseException: + except Exception as e: logger.error("Initialize CLS server engine Failed.") + logger.error(e) return False logger.info("Initialize CLS server engine successfully on device: %s." 
% (self.device)) return True + +class PaddleCLSConnectionHandler(CLSServerExecutor): + def __init__(self, cls_engine): + """The PaddleSpeech CLS Server Connection Handler + This connection process every cls server request + Args: + cls_engine (CLSEngine): The CLS engine + """ + super().__init__() + logger.info( + "Create PaddleCLSConnectionHandler to process the cls request") + + self._inputs = OrderedDict() + self._outputs = OrderedDict() + self.cls_engine = cls_engine + self.executor = self.cls_engine.executor + self._conf = self.executor._conf + self._label_list = self.executor._label_list + self.model = self.executor.model + def run(self, audio_data): """engine run Args: audio_data (bytes): base64.b64decode """ - self.executor.preprocess(io.BytesIO(audio_data)) + self.preprocess(io.BytesIO(audio_data)) st = time.time() - self.executor.infer() + self.infer() infer_time = time.time() - st logger.info("inference time: {}".format(infer_time)) @@ -108,15 +116,15 @@ class CLSEngine(BaseEngine): def postprocess(self, topk: int): """postprocess """ - assert topk <= len(self.executor._label_list - ), 'Value of topk is larger than number of labels.' + assert topk <= len( + self._label_list), 'Value of topk is larger than number of labels.' - result = self.executor._outputs['logits'].squeeze(0).numpy() + result = self._outputs['logits'].squeeze(0).numpy() topk_idx = (-result).argsort()[:topk] topk_results = [] for idx in topk_idx: res = {} - label, score = self.executor._label_list[idx], result[idx] + label, score = self._label_list[idx], result[idx] res['class_name'] = label res['prob'] = score topk_results.append(res) diff --git a/paddlespeech/server/engine/engine_warmup.py b/paddlespeech/server/engine/engine_warmup.py new file mode 100644 index 000000000..5f548f71d --- /dev/null +++ b/paddlespeech/server/engine/engine_warmup.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from paddlespeech.cli.log import logger +from paddlespeech.server.engine.engine_pool import get_engine_pool + + +def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: + engine_pool = get_engine_pool() + + if "tts" in engine_and_type: + tts_engine = engine_pool['tts'] + flag_online = False + if tts_engine.lang == 'zh': + sentence = "您好,欢迎使用语音合成服务。" + elif tts_engine.lang == 'en': + sentence = "Hello and welcome to the speech synthesis service." 
+ else: + logger.error("tts engine only support lang: zh or en.") + sys.exit(-1) + + if engine_and_type == "tts_python": + from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler + elif engine_and_type == "tts_inference": + from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler + elif engine_and_type == "tts_online": + from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler + flag_online = True + elif engine_and_type == "tts_online-onnx": + from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler + flag_online = True + else: + logger.error("Please check tte engine type.") + + try: + logger.info("Start to warm up tts engine.") + for i in range(warm_up_time): + connection_handler = PaddleTTSConnectionHandler(tts_engine) + if flag_online: + for wav in connection_handler.infer( + text=sentence, + lang=tts_engine.lang, + am=tts_engine.config.am): + logger.info( + f"The first response time of the {i} warm up: {connection_handler.first_response_time} s" + ) + break + + else: + st = time.time() + connection_handler.infer(text=sentence) + et = time.time() + logger.info( + f"The response time of the {i} warm up: {et - st} s") + except Exception as e: + logger.error("Failed to warm up on tts engine.") + logger.error(e) + return False + + else: + pass + + return True diff --git a/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py b/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py deleted file mode 100644 index 789f5be7d..000000000 --- a/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# support online model -pretrained_models = { - # fastspeech2 - "fastspeech2_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip', - 'md5': - 'fd3ad38d83273ad51f0ea4f4abf3ab4e', - 'ckpt': ['fastspeech2_csmsc.onnx'], - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - "fastspeech2_cnndecoder_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip', - 'md5': - '5f70e1a6bcd29d72d54e7931aa86f266', - 'ckpt': [ - 'fastspeech2_csmsc_am_encoder_infer.onnx', - 'fastspeech2_csmsc_am_decoder.onnx', - 'fastspeech2_csmsc_am_postnet.onnx', - ], - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - - # mb_melgan - "mb_melgan_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip', - 'md5': - '5b83ec746e8414bc29032d954ffd07ec', - 'ckpt': - 'mb_melgan_csmsc.onnx', - 'sample_rate': - 24000, - }, - - # hifigan - "hifigan_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip', - 'md5': - '1a7dc0385875889e46952e50c0994a6b', - 'ckpt': - 'hifigan_csmsc.onnx', - 'sample_rate': - 24000, - }, -} diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py index 792442065..cb9155a2d 100644 --- 
a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py @@ -20,9 +20,9 @@ from typing import Optional import numpy as np import paddle -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.onnx_infer import get_sess @@ -31,19 +31,13 @@ from paddlespeech.server.utils.util import get_chunks from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): - def __init__(self, am_block, am_pad, voc_block, voc_pad, voc_upsample): + def __init__(self): super().__init__() - self.am_block = am_block - self.am_pad = am_pad - self.voc_block = voc_block - self.voc_pad = voc_pad - self.voc_upsample = voc_upsample - - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource(task='tts', model_format='onnx') def _init_from_path( self, @@ -72,16 +66,21 @@ class TTSServerExecutor(TTSExecutor): return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) + self.am_res_path = self.task_resource.res_dir if am == "fastspeech2_csmsc_onnx": # get model info if am_ckpt is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path self.am_ckpt = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) + self.am_res_path, self.task_resource.res_dict['ckpt'][0]) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + 
self.am_res_path, + self.task_resource.res_dict['phones_dict']) else: self.am_ckpt = os.path.abspath(am_ckpt[0]) @@ -94,19 +93,19 @@ class TTSServerExecutor(TTSExecutor): elif am == "fastspeech2_cnndecoder_csmsc_onnx": if am_ckpt is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path self.am_encoder_infer = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) + self.am_res_path, self.task_resource.res_dict['ckpt'][0]) self.am_decoder = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][1]) + self.am_res_path, self.task_resource.res_dict['ckpt'][1]) self.am_postnet = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][2]) + self.am_res_path, self.task_resource.res_dict['ckpt'][2]) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + self.am_res_path, + self.task_resource.res_dict['phones_dict']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, + self.task_resource.res_dict['speech_stats']) else: self.am_encoder_infer = os.path.abspath(am_ckpt[0]) @@ -131,11 +130,15 @@ class TTSServerExecutor(TTSExecutor): # voc model info voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) else: self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt)) @@ -161,6 +164,115 @@ class TTSServerExecutor(TTSExecutor): self.frontend = 
English(phone_vocab_path=self.phones_dict) logger.info("frontend done!") + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super().__init__() + + def init(self, config: dict) -> bool: + self.executor = TTSServerExecutor() + self.config = config + self.lang = self.config.lang + self.engine_type = "online-onnx" + + self.am_block = self.config.am_block + self.am_pad = self.config.am_pad + self.voc_block = self.config.voc_block + self.voc_pad = self.config.voc_pad + self.am_upsample = 1 + self.voc_upsample = self.config.voc_upsample + + assert ( + self.config.am == "fastspeech2_csmsc_onnx" or + self.config.am == "fastspeech2_cnndecoder_csmsc_onnx" + ) and ( + self.config.voc == "hifigan_csmsc_onnx" or + self.config.voc == "mb_melgan_csmsc_onnx" + ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + + assert ( + self.config.voc_block > 0 and self.config.voc_pad > 0 + ), "Please set correct voc_block and voc_pad, they should be more than 0." + + assert ( + self.config.voc_sample_rate == self.config.am_sample_rate + ), "The sample rate of AM and Vocoder model are different, please check model." + + try: + if self.config.am_sess_conf.device is not None: + self.device = self.config.am_sess_conf.device + elif self.config.voc_sess_conf.device is not None: + self.device = self.config.voc_sess_conf.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error("Initialize TTS server engine Failed on device: %s." 
(self.device)) + logger.error(e) + return False + + try: + self.executor._init_from_path( + am=self.config.am, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + am_sample_rate=self.config.am_sample_rate, + am_sess_conf=self.config.am_sess_conf, + voc=self.config.voc, + voc_ckpt=self.config.voc_ckpt, + voc_sample_rate=self.config.voc_sample_rate, + voc_sess_conf=self.config.voc_sess_conf, + lang=self.config.lang) + + except Exception as e: + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.config.voc_sess_conf.device)) + logger.error(e) + return False + + logger.info("Initialize TTS server engine successfully on device: %s." % + (self.config.voc_sess_conf.device)) + + return True + + +class PaddleTTSConnectionHandler: + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine + """ + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.am_block = self.tts_engine.am_block + self.am_pad = self.tts_engine.am_pad + self.voc_block = self.tts_engine.voc_block + self.voc_pad = self.tts_engine.voc_pad + self.am_upsample = self.tts_engine.am_upsample + self.voc_upsample = self.tts_engine.voc_upsample + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): """ Streaming inference removes the result of pad inference @@ -189,12 +301,6 @@ class TTSServerExecutor(TTSExecutor): Model inference and result stored in self.output. 
""" - am_block = self.am_block - am_pad = self.am_pad - am_upsample = 1 - voc_block = self.voc_block - voc_pad = self.voc_pad - voc_upsample = self.voc_upsample # first_flag 用于标记首包 first_flag = 1 get_tone_ids = False @@ -203,7 +309,7 @@ class TTSServerExecutor(TTSExecutor): # front frontend_st = time.time() if lang == 'zh': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) @@ -211,7 +317,7 @@ class TTSServerExecutor(TTSExecutor): if get_tone_ids: tone_ids = input_ids["tone_ids"] elif lang == 'en': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: @@ -226,7 +332,7 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_csmsc if am == "fastspeech2_csmsc_onnx": # am - mel = self.am_sess.run( + mel = self.executor.am_sess.run( output_names=None, input_feed={'text': part_phone_ids}) mel = mel[0] if first_flag == 1: @@ -234,14 +340,16 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et # voc streaming - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad, + "voc") voc_chunk_num = len(mel_chunks) voc_st = time.time() for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_sess.run( + sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': mel_chunk}) sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i, - voc_block, voc_pad, voc_upsample) + self.voc_block, self.voc_pad, + self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -253,7 +361,7 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_cnndecoder_csmsc elif am == "fastspeech2_cnndecoder_csmsc_onnx": # am - orig_hs = self.am_encoder_infer_sess.run( + orig_hs = 
self.executor.am_encoder_infer_sess.run( None, input_feed={'text': part_phone_ids}) orig_hs = orig_hs[0] @@ -267,9 +375,9 @@ class TTSServerExecutor(TTSExecutor): hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): - am_decoder_output = self.am_decoder_sess.run( + am_decoder_output = self.executor.am_decoder_sess.run( None, input_feed={'xs': hs}) - am_postnet_output = self.am_postnet_sess.run( + am_postnet_output = self.executor.am_postnet_sess.run( None, input_feed={ 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) @@ -278,9 +386,11 @@ class TTSServerExecutor(TTSExecutor): am_postnet_output[0], (0, 2, 1)) normalized_mel = am_output_data[0][0] - sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) - sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, - am_pad, am_upsample) + sub_mel = denorm(normalized_mel, self.executor.am_mu, + self.executor.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, + self.am_block, self.am_pad, + self.am_upsample) if i == 0: mel_streaming = sub_mel @@ -297,11 +407,11 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] - sub_wav = self.voc_sess.run( + sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': voc_chunk}) - sub_wav = self.depadding(sub_wav[0], voc_chunk_num, - voc_chunk_id, voc_block, - voc_pad, voc_upsample) + sub_wav = self.depadding( + sub_wav[0], voc_chunk_num, voc_chunk_id, + self.voc_block, self.voc_pad, self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -311,9 +421,11 @@ class TTSServerExecutor(TTSExecutor): yield sub_wav voc_chunk_id += 1 - start = max(0, voc_chunk_id * voc_block - voc_pad) - end = min((voc_chunk_id + 1) * voc_block + voc_pad, - mel_len) + start = max( + 0, voc_chunk_id * self.voc_block - self.voc_pad) + end = min( + (voc_chunk_id + 1) * 
self.voc_block + self.voc_pad, + mel_len) else: logger.error( @@ -322,111 +434,6 @@ class TTSServerExecutor(TTSExecutor): self.final_response_time = time.time() - frontend_st - -class TTSEngine(BaseEngine): - """TTS server engine - - Args: - metaclass: Defaults to Singleton. - """ - - def __init__(self, name=None): - """Initialize TTS server engine - """ - super().__init__() - - def init(self, config: dict) -> bool: - self.config = config - assert ( - self.config.am == "fastspeech2_csmsc_onnx" or - self.config.am == "fastspeech2_cnndecoder_csmsc_onnx" - ) and ( - self.config.voc == "hifigan_csmsc_onnx" or - self.config.voc == "mb_melgan_csmsc_onnx" - ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' - - assert ( - self.config.voc_block > 0 and self.config.voc_pad > 0 - ), "Please set correct voc_block and voc_pad, they should be more than 0." - - assert ( - self.config.voc_sample_rate == self.config.am_sample_rate - ), "The sample rate of AM and Vocoder model are different, please check model." - - self.executor = TTSServerExecutor( - self.config.am_block, self.config.am_pad, self.config.voc_block, - self.config.voc_pad, self.config.voc_upsample) - - try: - if self.config.am_sess_conf.device is not None: - self.device = self.config.am_sess_conf.device - elif self.config.voc_sess_conf.device is not None: - self.device = self.config.voc_sess_conf.device - else: - self.device = paddle.get_device() - paddle.set_device(self.device) - except BaseException as e: - logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" - ) - logger.error("Initialize TTS server engine Failed on device: %s." 
% - (self.device)) - return False - - try: - self.executor._init_from_path( - am=self.config.am, - am_ckpt=self.config.am_ckpt, - am_stat=self.config.am_stat, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - am_sample_rate=self.config.am_sample_rate, - am_sess_conf=self.config.am_sess_conf, - voc=self.config.voc, - voc_ckpt=self.config.voc_ckpt, - voc_sample_rate=self.config.voc_sample_rate, - voc_sess_conf=self.config.voc_sess_conf, - lang=self.config.lang) - - except Exception as e: - logger.error("Failed to get model related files.") - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.config.voc_sess_conf.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") - return False - - logger.info("Initialize TTS server engine successfully on device: %s." % - (self.config.voc_sess_conf.device)) - - return True - - def warm_up(self): - """warm up - """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- logger.info("Start to warm up.") - for i in range(3): - for wav in self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ): - logger.info( - f"The first response time of the {i} warm up: {self.executor.first_response_time} s" - ) - break - def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -459,7 +466,7 @@ class TTSEngine(BaseEngine): """ wav_list = [] - for wav in self.executor.infer( + for wav in self.infer( text=sentence, lang=self.config.lang, am=self.config.am, @@ -477,11 +484,9 @@ class TTSEngine(BaseEngine): duration = len(wav_all) / self.config.voc_sample_rate logger.info(f"sentence: {sentence}") logger.info(f"The durations of audio is: {duration} s") + logger.info(f"first response time: {self.first_response_time} s") + logger.info(f"final response time: {self.final_response_time} s") + logger.info(f"RTF: {self.final_response_time / duration}") logger.info( - f"first response time: {self.executor.first_response_time} s") - logger.info( - f"final response time: {self.executor.final_response_time} s") - logger.info(f"RTF: {self.executor.final_response_time / duration}") - logger.info( - f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s," ) diff --git a/paddlespeech/server/engine/tts/online/python/pretrained_models.py b/paddlespeech/server/engine/tts/online/python/pretrained_models.py deleted file mode 100644 index bf6aded51..000000000 --- a/paddlespeech/server/engine/tts/online/python/pretrained_models.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# support online model -pretrained_models = { - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_cnndecoder_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', - 'md5': - '6eb28e22ace73e0ebe7845f86478f89f', - 'config': - 'cnndecoder.yaml', - 'ckpt': - 'snapshot_iter_153000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, -} diff --git 
a/paddlespeech/server/engine/tts/online/python/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py index 1fca52837..2e8997e0f 100644 --- a/paddlespeech/server/engine/tts/online/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py @@ -22,10 +22,9 @@ import paddle import yaml from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.util import denorm @@ -34,17 +33,14 @@ from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): - def __init__(self, am_block, am_pad, voc_block, voc_pad): + def __init__(self): super().__init__() - self.am_block = am_block - self.am_pad = am_pad - self.voc_block = voc_block - self.voc_pad = voc_pad - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='tts', model_format='dynamic', inference_mode='online') def get_model_info(self, field: str, @@ -65,7 +61,7 @@ class TTSServerExecutor(TTSExecutor): [Tensor]: standard deviation """ - model_class = dynamic_import(model_name, self.model_alias) + model_class = self.task_resource.get_model_class(model_name) if field == "am": odim = self.am_config.n_mels @@ -110,20 +106,24 @@ class TTSServerExecutor(TTSExecutor): return # am model info am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) if am_ckpt is None or 
am_config is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_config = os.path.join( - am_res_path, self.pretrained_models[am_tag]['config']) - self.am_ckpt = os.path.join(am_res_path, - self.pretrained_models[am_tag]['ckpt']) + self.am_res_path = self.task_resource.res_dir + self.am_config = os.path.join(self.am_res_path, + self.task_resource.res_dict['config']) + self.am_ckpt = os.path.join(self.am_res_path, + self.task_resource.res_dict['ckpt']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, self.task_resource.res_dict['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + self.am_res_path, self.task_resource.res_dict['phones_dict']) print("self.phones_dict:", self.phones_dict) - logger.info(am_res_path) + logger.info(self.am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) else: @@ -139,16 +139,21 @@ class TTSServerExecutor(TTSExecutor): # voc model info voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None or voc_config is None or voc_stat is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) self.voc_stat = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) - logger.info(voc_res_path) + self.voc_res_path, + self.task_resource.voc_res_dict['speech_stats']) + 
logger.info(self.voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) else: @@ -188,8 +193,8 @@ class TTSServerExecutor(TTSExecutor): am, am_mu, am_std = self.get_model_info("am", self.am_name, self.am_ckpt, self.am_stat) am_normalizer = ZScore(am_mu, am_std) - am_inference_class = dynamic_import(self.am_name + '_inference', - self.model_alias) + am_inference_class = self.task_resource.get_model_class( + self.am_name + '_inference') self.am_inference = am_inference_class(am_normalizer, am) self.am_inference.eval() print("acoustic model done!") @@ -199,12 +204,112 @@ class TTSServerExecutor(TTSExecutor): voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name, self.voc_ckpt, self.voc_stat) voc_normalizer = ZScore(voc_mu, voc_std) - voc_inference_class = dynamic_import(self.voc_name + '_inference', - self.model_alias) + voc_inference_class = self.task_resource.get_model_class(self.voc_name + + '_inference') self.voc_inference = voc_inference_class(voc_normalizer, voc) self.voc_inference.eval() print("voc done!") + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super().__init__() + + def init(self, config: dict) -> bool: + self.executor = TTSServerExecutor() + self.config = config + self.lang = self.config.lang + self.engine_type = "online" + + assert ( + config.am == "fastspeech2_csmsc" or + config.am == "fastspeech2_cnndecoder_csmsc" + ) and ( + config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" + ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + + assert ( + config.voc_block > 0 and config.voc_pad > 0 + ), "Please set correct voc_block and voc_pad, they should be more than 0." 
+ + try: + if self.config.device is not None: + self.device = self.config.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) + return False + + try: + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + except Exception as e: + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) + return False + + self.am_block = self.config.am_block + self.am_pad = self.config.am_pad + self.voc_block = self.config.voc_block + self.voc_pad = self.config.voc_pad + self.am_upsample = 1 + self.voc_upsample = self.executor.voc_config.n_shift + + logger.info("Initialize TTS server engine successfully on device: %s." 
% + (self.device)) + + return True + + +class PaddleTTSConnectionHandler: + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine + """ + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.am_block = self.tts_engine.am_block + self.am_pad = self.tts_engine.am_pad + self.voc_block = self.tts_engine.voc_block + self.voc_pad = self.tts_engine.voc_pad + self.am_upsample = self.tts_engine.am_upsample + self.voc_upsample = self.tts_engine.voc_upsample + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): """ Streaming inference removes the result of pad inference @@ -233,12 +338,6 @@ class TTSServerExecutor(TTSExecutor): Model inference and result stored in self.output. """ - am_block = self.am_block - am_pad = self.am_pad - am_upsample = 1 - voc_block = self.voc_block - voc_pad = self.voc_pad - voc_upsample = self.voc_config.n_shift # first_flag 用于标记首包 first_flag = 1 @@ -246,7 +345,7 @@ class TTSServerExecutor(TTSExecutor): merge_sentences = False frontend_st = time.time() if lang == 'zh': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) @@ -254,7 +353,7 @@ class TTSServerExecutor(TTSExecutor): if get_tone_ids: tone_ids = input_ids["tone_ids"] elif lang == 'en': - input_ids = self.frontend.get_input_ids( + input_ids = self.executor.frontend.get_input_ids( text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: @@ -269,19 +368,21 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_csmsc if am == "fastspeech2_csmsc": # am - mel = self.am_inference(part_phone_ids) + mel = self.executor.am_inference(part_phone_ids) if 
first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et # voc streaming - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + mel_chunks = get_chunks(mel, self.voc_block, self.voc_pad, + "voc") voc_chunk_num = len(mel_chunks) voc_st = time.time() for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_inference(mel_chunk) + sub_wav = self.executor.voc_inference(mel_chunk) sub_wav = self.depadding(sub_wav, voc_chunk_num, i, - voc_block, voc_pad, voc_upsample) + self.voc_block, self.voc_pad, + self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -293,7 +394,8 @@ class TTSServerExecutor(TTSExecutor): # fastspeech2_cnndecoder_csmsc elif am == "fastspeech2_cnndecoder_csmsc": # am - orig_hs = self.am_inference.encoder_infer(part_phone_ids) + orig_hs = self.executor.am_inference.encoder_infer( + part_phone_ids) # streaming voc chunk info mel_len = orig_hs.shape[1] @@ -305,13 +407,15 @@ class TTSServerExecutor(TTSExecutor): hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): - before_outs = self.am_inference.decoder(hs) - after_outs = before_outs + self.am_inference.postnet( + before_outs = self.executor.am_inference.decoder(hs) + after_outs = before_outs + self.executor.am_inference.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) normalized_mel = after_outs[0] - sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) - sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, - am_pad, am_upsample) + sub_mel = denorm(normalized_mel, self.executor.am_mu, + self.executor.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, + self.am_block, self.am_pad, + self.am_upsample) if i == 0: mel_streaming = sub_mel @@ -328,11 +432,11 @@ class TTSServerExecutor(TTSExecutor): self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] voc_chunk = 
paddle.to_tensor(voc_chunk) - sub_wav = self.voc_inference(voc_chunk) + sub_wav = self.executor.voc_inference(voc_chunk) - sub_wav = self.depadding(sub_wav, voc_chunk_num, - voc_chunk_id, voc_block, - voc_pad, voc_upsample) + sub_wav = self.depadding( + sub_wav, voc_chunk_num, voc_chunk_id, + self.voc_block, self.voc_pad, self.voc_upsample) if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et @@ -342,9 +446,11 @@ class TTSServerExecutor(TTSExecutor): yield sub_wav voc_chunk_id += 1 - start = max(0, voc_chunk_id * voc_block - voc_pad) - end = min((voc_chunk_id + 1) * voc_block + voc_pad, - mel_len) + start = max( + 0, voc_chunk_id * self.voc_block - self.voc_pad) + end = min( + (voc_chunk_id + 1) * self.voc_block + self.voc_pad, + mel_len) else: logger.error( @@ -353,100 +459,6 @@ class TTSServerExecutor(TTSExecutor): self.final_response_time = time.time() - frontend_st - -class TTSEngine(BaseEngine): - """TTS server engine - - Args: - metaclass: Defaults to Singleton. - """ - - def __init__(self, name=None): - """Initialize TTS server engine - """ - super().__init__() - - def init(self, config: dict) -> bool: - self.config = config - assert ( - config.am == "fastspeech2_csmsc" or - config.am == "fastspeech2_cnndecoder_csmsc" - ) and ( - config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" - ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' - - assert ( - config.voc_block > 0 and config.voc_pad > 0 - ), "Please set correct voc_block and voc_pad, they should be more than 0." - - try: - if self.config.device is not None: - self.device = self.config.device - else: - self.device = paddle.get_device() - paddle.set_device(self.device) - except Exception as e: - logger.error( - "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" - ) - logger.error("Initialize TTS server engine Failed on device: %s." 
% - (self.device)) - return False - - self.executor = TTSServerExecutor(config.am_block, config.am_pad, - config.voc_block, config.voc_pad) - - try: - self.executor._init_from_path( - am=self.config.am, - am_config=self.config.am_config, - am_ckpt=self.config.am_ckpt, - am_stat=self.config.am_stat, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_config=self.config.voc_config, - voc_ckpt=self.config.voc_ckpt, - voc_stat=self.config.voc_stat, - lang=self.config.lang) - except Exception as e: - logger.error("Failed to get model related files.") - logger.error("Initialize TTS server engine Failed on device: %s." % - (self.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") - return False - - logger.info("Initialize TTS server engine successfully on device: %s." % - (self.device)) - return True - - def warm_up(self): - """warm up - """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- logger.info("Start to warm up.") - for i in range(3): - for wav in self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ): - logger.info( - f"The first response time of the {i} warm up: {self.executor.first_response_time} s" - ) - break - def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -480,7 +492,7 @@ class TTSEngine(BaseEngine): wav_list = [] - for wav in self.executor.infer( + for wav in self.infer( text=sentence, lang=self.config.lang, am=self.config.am, @@ -496,13 +508,12 @@ class TTSEngine(BaseEngine): wav_all = np.concatenate(wav_list, axis=0) duration = len(wav_all) / self.executor.am_config.fs + logger.info(f"sentence: {sentence}") logger.info(f"The durations of audio is: {duration} s") + logger.info(f"first response time: {self.first_response_time} s") + logger.info(f"final response time: {self.final_response_time} s") + logger.info(f"RTF: {self.final_response_time / duration}") logger.info( - f"first response time: {self.executor.first_response_time} s") - logger.info( - f"final response time: {self.executor.final_response_time} s") - logger.info(f"RTF: {self.executor.final_response_time / duration}") - logger.info( - f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + f"Other info: front time: {self.frontend_time} s, first am infer time: {self.first_am_infer} s, first voc infer time: {self.first_voc_infer} s," ) diff --git a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py b/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py deleted file mode 100644 index 9618a7a69..000000000 --- a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Static model applied on paddle inference -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip', - 'md5': - 'f10cbdedf47dc7a9668d2264494e1823', - 'model': - 'speedyspeech_csmsc.pdmodel', - 'params': - 'speedyspeech_csmsc.pdiparams', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - 'sample_rate': - 24000, - }, - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip', - 'md5': - '9788cd9745e14c7a5d12d32670b2a5a7', - 'model': - 'fastspeech2_csmsc.pdmodel', - 'params': - 'fastspeech2_csmsc.pdiparams', - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip', - 'md5': - 'e3504aed9c5a290be12d1347836d2742', - 'model': - 'pwgan_csmsc.pdmodel', - 'params': - 'pwgan_csmsc.pdiparams', - 'sample_rate': - 24000, - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip', - 'md5': - 'ac6eee94ba483421d750433f4c3b8d36', - 'model': - 'mb_melgan_csmsc.pdmodel', - 'params': - 'mb_melgan_csmsc.pdiparams', - 
'sample_rate': - 24000, - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip', - 'md5': - '7edd8c436b3a5546b3a7cb8cff9d5a0c', - 'model': - 'hifigan_csmsc.pdmodel', - 'params': - 'hifigan_csmsc.pdiparams', - 'sample_rate': - 24000, - }, -} diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index f1ce8b76e..ab5b721ff 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -14,6 +14,7 @@ import base64 import io import os +import sys import time from typing import Optional @@ -23,9 +24,9 @@ import paddle import soundfile as sf from scipy.io import wavfile -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.errors import ErrorCode @@ -35,13 +36,14 @@ from paddlespeech.server.utils.paddle_predictor import run_model from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='tts', model_format='static') def _init_from_path( self, @@ -67,19 +69,23 @@ class TTSServerExecutor(TTSExecutor): return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) if am_model is None or am_params is None or phones_dict is None: - am_res_path = 
self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_model = os.path.join( - am_res_path, self.pretrained_models[am_tag]['model']) - self.am_params = os.path.join( - am_res_path, self.pretrained_models[am_tag]['params']) + self.am_res_path = self.task_resource.res_dir + self.am_model = os.path.join(self.am_res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.am_res_path, + self.task_resource.res_dict['params']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) - self.am_sample_rate = self.pretrained_models[am_tag]['sample_rate'] + self.am_res_path, self.task_resource.res_dict['phones_dict']) + self.am_sample_rate = self.task_resource.res_dict['sample_rate'] - logger.info(am_res_path) + logger.info(self.am_res_path) logger.info(self.am_model) logger.info(self.am_params) else: @@ -92,32 +98,36 @@ class TTSServerExecutor(TTSExecutor): # for speedyspeech self.tones_dict = None - if 'tones_dict' in self.pretrained_models[am_tag]: + if 'tones_dict' in self.task_resource.res_dict: self.tones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['tones_dict']) + self.am_res_path, self.task_resource.res_dict['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in self.pretrained_models[am_tag]: + if 'speaker_dict' in self.task_resource.res_dict: self.speaker_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speaker_dict']) + self.am_res_path, self.task_resource.res_dict['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict # voc voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_model is None or voc_params is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + 
self.voc_res_path = self.task_resource.voc_res_dir self.voc_model = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['model']) + self.voc_res_path, self.task_resource.voc_res_dict['model']) self.voc_params = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['params']) - self.voc_sample_rate = self.pretrained_models[voc_tag][ + self.voc_res_path, self.task_resource.voc_res_dict['params']) + self.voc_sample_rate = self.task_resource.voc_res_dict[ 'sample_rate'] - logger.info(voc_res_path) + logger.info(self.voc_res_path) logger.info(self.voc_model) logger.info(self.voc_params) else: @@ -245,7 +255,7 @@ class TTSServerExecutor(TTSExecutor): else: wav_all = paddle.concat([wav_all, wav]) self.voc_time += (time.time() - voc_st) - self._outputs['wav'] = wav_all + self._outputs["wav"] = wav_all class TTSEngine(BaseEngine): @@ -263,6 +273,8 @@ class TTSEngine(BaseEngine): def init(self, config: dict) -> bool: self.executor = TTSServerExecutor() self.config = config + self.lang = self.config.lang + self.engine_type = "inference" try: if self.config.am_predictor_conf.device is not None: @@ -272,58 +284,59 @@ class TTSEngine(BaseEngine): else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException as e: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) logger.error("Initialize TTS server engine Failed on device: %s." 
% (self.device)) + logger.error(e) return False - self.executor._init_from_path( - am=self.config.am, - am_model=self.config.am_model, - am_params=self.config.am_params, - am_sample_rate=self.config.am_sample_rate, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_model=self.config.voc_model, - voc_params=self.config.voc_params, - voc_sample_rate=self.config.voc_sample_rate, - lang=self.config.lang, - am_predictor_conf=self.config.am_predictor_conf, - voc_predictor_conf=self.config.voc_predictor_conf, ) - - # warm up try: - self.warm_up() - logger.info("Warm up successfully.") + self.executor._init_from_path( + am=self.config.am, + am_model=self.config.am_model, + am_params=self.config.am_params, + am_sample_rate=self.config.am_sample_rate, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_model=self.config.voc_model, + voc_params=self.config.voc_params, + voc_sample_rate=self.config.voc_sample_rate, + lang=self.config.lang, + am_predictor_conf=self.config.am_predictor_conf, + voc_predictor_conf=self.config.voc_predictor_conf, ) except Exception as e: - logger.error("Failed to warm up on tts engine.") + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + logger.error(e) return False logger.info("Initialize TTS server engine successfully.") return True - def warm_up(self): - """warm up + +class PaddleTTSConnectionHandler(TTSServerExecutor): + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- logger.info("Start to warm up.") - for i in range(3): - st = time.time() - self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ) - logger.info( - f"The response time of the {i} warm up: {time.time() - st} s") + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.frontend = self.executor.frontend + self.am_predictor = self.executor.am_predictor + self.voc_predictor = self.executor.voc_predictor def postprocess(self, wav, @@ -375,8 +388,11 @@ class TTSEngine(BaseEngine): ErrorCode.SERVER_INTERNAL_ERR, "Failed to transform speed. Can not install soxbindings on your system. \ You need to set speed value 1.0.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("Failed to transform speed.") + logger.error(e) + sys.exit(-1) # wav to base64 buf = io.BytesIO() @@ -433,7 +449,7 @@ class TTSEngine(BaseEngine): try: infer_st = time.time() - self.executor.infer( + self.infer( text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st @@ -441,13 +457,16 @@ class TTSEngine(BaseEngine): except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), + wav=self._outputs["wav"].numpy(), original_fs=self.executor.am_sample_rate, target_fs=sample_rate, volume=volume, @@ -455,26 +474,28 @@ class TTSEngine(BaseEngine): audio_path=save_path) postprocess_et = time.time() postprocess_time = postprocess_et - postprocess_st - duration = len(self.executor._outputs['wav'] - .numpy()) / 
self.executor.am_sample_rate + duration = len( + self._outputs["wav"].numpy()) / self.executor.am_sample_rate rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) logger.info("AM model: {}".format(self.config.am)) logger.info("Vocoder model: {}".format(self.config.voc)) logger.info("Language: {}".format(lang)) - logger.info("tts engine type: paddle inference") + logger.info("tts engine type: python") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". @@ -482,5 +503,6 @@ class TTSEngine(BaseEngine): logger.info("total generate audio time: {}".format(infer_time + postprocess_time)) logger.info("RTF: {}".format(rtf)) + logger.info("device: {}".format(self.tts_engine.device)) return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py index d0002baa4..b048b01a4 100644 --- a/paddlespeech/server/engine/tts/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -13,6 +13,7 @@ # limitations under the License. 
import base64 import io +import sys import time import librosa @@ -28,7 +29,7 @@ from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.exception import ServerBaseException -__all__ = ['TTSEngine'] +__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] class TTSServerExecutor(TTSExecutor): @@ -52,6 +53,8 @@ class TTSEngine(BaseEngine): def init(self, config: dict) -> bool: self.executor = TTSServerExecutor() self.config = config + self.lang = self.config.lang + self.engine_type = "python" try: if self.config.device is not None: @@ -59,12 +62,13 @@ class TTSEngine(BaseEngine): else: self.device = paddle.get_device() paddle.set_device(self.device) - except BaseException as e: + except Exception as e: logger.error( "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" ) logger.error("Initialize TTS server engine Failed on device: %s." % (self.device)) + logger.error(e) return False try: @@ -81,41 +85,35 @@ class TTSEngine(BaseEngine): voc_ckpt=self.config.voc_ckpt, voc_stat=self.config.voc_stat, lang=self.config.lang) - except BaseException: + except Exception as e: logger.error("Failed to get model related files.") logger.error("Initialize TTS server engine Failed on device: %s." % (self.device)) - return False - - # warm up - try: - self.warm_up() - logger.info("Warm up successfully.") - except Exception as e: - logger.error("Failed to warm up on tts engine.") + logger.error(e) return False logger.info("Initialize TTS server engine successfully on device: %s." 
% (self.device)) return True - def warm_up(self): - """warm up + +class PaddleTTSConnectionHandler(TTSServerExecutor): + def __init__(self, tts_engine): + """The PaddleSpeech TTS Server Connection Handler + This connection process every tts server request + Args: + tts_engine (TTSEngine): The TTS engine """ - if self.config.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if self.config.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." - logger.info("Start to warm up.") - for i in range(3): - st = time.time() - self.executor.infer( - text=sentence, - lang=self.config.lang, - am=self.config.am, - spk_id=0, ) - logger.info( - f"The response time of the {i} warm up: {time.time() - st} s") + super().__init__() + logger.info( + "Create PaddleTTSConnectionHandler to process the tts request") + + self.tts_engine = tts_engine + self.executor = self.tts_engine.executor + self.config = self.tts_engine.config + self.frontend = self.executor.frontend + self.am_inference = self.executor.am_inference + self.voc_inference = self.executor.voc_inference def postprocess(self, wav, @@ -167,8 +165,11 @@ class TTSEngine(BaseEngine): ErrorCode.SERVER_INTERNAL_ERR, "Failed to transform speed. Can not install soxbindings on your system. 
\ You need to set speed value 1.0.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("Failed to transform speed.") + logger.error(e) + sys.exit(-1) # wav to base64 buf = io.BytesIO() @@ -225,24 +226,27 @@ class TTSEngine(BaseEngine): try: infer_st = time.time() - self.executor.infer( + self.infer( text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) infer_et = time.time() infer_time = infer_et - infer_st - duration = len(self.executor._outputs['wav'] - .numpy()) / self.executor.am_config.fs + duration = len( + self._outputs["wav"].numpy()) / self.executor.am_config.fs rtf = infer_time / duration except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts infer failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts infer failed.") + logger.error(e) + sys.exit(-1) try: postprocess_st = time.time() target_sample_rate, wav_base64 = self.postprocess( - wav=self.executor._outputs['wav'].numpy(), + wav=self._outputs["wav"].numpy(), original_fs=self.executor.am_config.fs, target_fs=sample_rate, volume=volume, @@ -254,8 +258,11 @@ class TTSEngine(BaseEngine): except ServerBaseException: raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, "tts postprocess failed.") - except BaseException: + sys.exit(-1) + except Exception as e: logger.error("tts postprocess failed.") + logger.error(e) + sys.exit(-1) logger.info("AM model: {}".format(self.config.am)) logger.info("Vocoder model: {}".format(self.config.voc)) @@ -263,10 +270,9 @@ class TTSEngine(BaseEngine): logger.info("tts engine type: python") logger.info("audio duration: {}".format(duration)) - logger.info( - "frontend inference time: {}".format(self.executor.frontend_time)) - logger.info("AM inference time: {}".format(self.executor.am_time)) - logger.info("Vocoder inference time: {}".format(self.executor.voc_time)) + logger.info("frontend inference time: {}".format(self.frontend_time)) + logger.info("AM inference 
time: {}".format(self.am_time)) + logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.info("total inference time: {}".format(infer_time)) logger.info( "postprocess (change speed, volume, target sample rate) time: {}". @@ -274,6 +280,6 @@ class TTSEngine(BaseEngine): logger.info("total generate audio time: {}".format(infer_time + postprocess_time)) logger.info("RTF: {}".format(rtf)) - logger.info("device: {}".format(self.device)) + logger.info("device: {}".format(self.tts_engine.device)) return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py index 854303701..3c72f55d4 100644 --- a/paddlespeech/server/engine/vector/python/vector_engine.py +++ b/paddlespeech/server/engine/vector/python/vector_engine.py @@ -16,9 +16,9 @@ from collections import OrderedDict import numpy as np import paddle -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram +from paddlespeech.audio.backends import load as load_audio +from paddlespeech.audio.compliance.librosa import melspectrogram from paddlespeech.cli.log import logger from paddlespeech.cli.vector.infer import VectorExecutor from paddlespeech.server.engine.base_engine import BaseEngine diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py index 1c2dd2814..9722c2614 100644 --- a/paddlespeech/server/restful/api.py +++ b/paddlespeech/server/restful/api.py @@ -17,12 +17,12 @@ from typing import List from fastapi import APIRouter from paddlespeech.cli.log import logger +from paddlespeech.server.restful.acs_api import router as acs_router from paddlespeech.server.restful.asr_api import router as asr_router from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.text_api import router as text_router from paddlespeech.server.restful.tts_api import router as 
tts_router from paddlespeech.server.restful.vector_api import router as vec_router -from paddlespeech.server.restful.acs_api import router as acs_router _router = APIRouter() diff --git a/paddlespeech/server/restful/asr_api.py b/paddlespeech/server/restful/asr_api.py index cf46735dc..c7bc50ce4 100644 --- a/paddlespeech/server/restful/asr_api.py +++ b/paddlespeech/server/restful/asr_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import base64 +import sys import traceback from typing import Union from fastapi import APIRouter +from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.restful.request import ASRRequest from paddlespeech.server.restful.response import ASRResponse @@ -68,8 +70,18 @@ def asr(request_body: ASRRequest): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] - asr_engine.run(audio_data) - asr_results = asr_engine.postprocess() + if asr_engine.engine_type == "python": + from paddlespeech.server.engine.asr.python.asr_engine import PaddleASRConnectionHandler + elif asr_engine.engine_type == "inference": + from paddlespeech.server.engine.asr.paddleinference.asr_engine import PaddleASRConnectionHandler + else: + logger.error("Offline asr engine only support python or inference.") + sys.exit(-1) + + connection_handler = PaddleASRConnectionHandler(asr_engine) + + connection_handler.run(audio_data) + asr_results = connection_handler.postprocess() response = { "success": True, diff --git a/paddlespeech/server/restful/cls_api.py b/paddlespeech/server/restful/cls_api.py index 306d9ca9c..7cfb4a297 100644 --- a/paddlespeech/server/restful/cls_api.py +++ b/paddlespeech/server/restful/cls_api.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import base64 +import sys import traceback from typing import Union from fastapi import APIRouter +from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.restful.request import CLSRequest from paddlespeech.server.restful.response import CLSResponse @@ -68,8 +70,18 @@ def cls(request_body: CLSRequest): engine_pool = get_engine_pool() cls_engine = engine_pool['cls'] - cls_engine.run(audio_data) - cls_results = cls_engine.postprocess(request_body.topk) + if cls_engine.engine_type == "python": + from paddlespeech.server.engine.cls.python.cls_engine import PaddleCLSConnectionHandler + elif cls_engine.engine_type == "inference": + from paddlespeech.server.engine.cls.paddleinference.cls_engine import PaddleCLSConnectionHandler + else: + logger.error("Offline cls engine only support python or inference.") + sys.exit(-1) + + connection_handler = PaddleCLSConnectionHandler(cls_engine) + + connection_handler.run(audio_data) + cls_results = connection_handler.postprocess(request_body.topk) response = { "success": True, @@ -85,8 +97,11 @@ def cls(request_body: CLSRequest): except ServerBaseException as e: response = failed_response(e.error_code, e.msg) - except BaseException: + logger.error(e) + sys.exit(-1) + except Exception as e: response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + logger.error(e) traceback.print_exc() return response diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py index 15d618d93..53fe159fd 100644 --- a/paddlespeech/server/restful/tts_api.py +++ b/paddlespeech/server/restful/tts_api.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import sys import traceback from typing import Union @@ -99,7 +100,16 @@ def tts(request_body: TTSRequest): tts_engine = engine_pool['tts'] logger.info("Get tts engine successfully.") - lang, target_sample_rate, duration, wav_base64 = tts_engine.run( + if tts_engine.engine_type == "python": + from paddlespeech.server.engine.tts.python.tts_engine import PaddleTTSConnectionHandler + elif tts_engine.engine_type == "inference": + from paddlespeech.server.engine.tts.paddleinference.tts_engine import PaddleTTSConnectionHandler + else: + logger.error("Offline tts engine only support python or inference.") + sys.exit(-1) + + connection_handler = PaddleTTSConnectionHandler(tts_engine) + lang, target_sample_rate, duration, wav_base64 = connection_handler.run( text, spk_id, speed, volume, sample_rate, save_path) response = { @@ -136,4 +146,14 @@ async def stream_tts(request_body: TTSRequest): tts_engine = engine_pool['tts'] logger.info("Get tts engine successfully.") - return StreamingResponse(tts_engine.run(sentence=text)) + if tts_engine.engine_type == "online": + from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler + elif tts_engine.engine_type == "online-onnx": + from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler + else: + logger.error("Online tts engine only support online or online-onnx.") + sys.exit(-1) + + connection_handler = PaddleTTSConnectionHandler(tts_engine) + + return StreamingResponse(connection_handler.run(sentence=text)) diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py index ae3e9c6aa..32546a330 100644 --- a/paddlespeech/server/util.py +++ b/paddlespeech/server/util.py @@ -24,14 +24,14 @@ from typing import Any from typing import Dict import paddle -import paddleaudio import requests import yaml from paddle.framework import load -from . 
import download +import paddlespeech.audio from .entry import client_commands from .entry import server_commands +from paddlespeech.cli import download try: from .. import __version__ except ImportError: @@ -289,7 +289,7 @@ def _note_one_stat(cls_name, params={}): if 'audio_file' in params: try: - _, sr = paddleaudio.load(params['audio_file']) + _, sr = paddlespeech.audio.load(params['audio_file']) except Exception: sr = -1 diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py index baa7b9343..e3d90d469 100644 --- a/paddlespeech/server/utils/audio_handler.py +++ b/paddlespeech/server/utils/audio_handler.py @@ -248,7 +248,7 @@ class ASRHttpHandler: } res = requests.post(url=self.url, data=json.dumps(data)) - + return res.json() diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index f56db752d..20cd3cf62 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ class Frame(object): """Represents a "frame" of audio data.""" @@ -45,7 +46,7 @@ class ChunkBuffer(object): self.shift_ms = shift_ms self.sample_rate = sample_rate self.sample_width = sample_width # int16 = 2; float32 = 4 - + self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_ms) / 1000.0 self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0) @@ -77,8 +78,8 @@ class ChunkBuffer(object): offset = 0 while offset + self.window_bytes <= len(audio): - yield Frame(audio[offset:offset + self.window_bytes], self.timestamp, - self.window_sec) + yield Frame(audio[offset:offset + self.window_bytes], + self.timestamp, self.window_sec) self.timestamp += self.shift_sec offset += self.shift_bytes diff --git a/paddlespeech/server/ws/asr_api.py b/paddlespeech/server/ws/asr_api.py index 0faa131aa..ae1c88310 100644 --- a/paddlespeech/server/ws/asr_api.py +++ b/paddlespeech/server/ws/asr_api.py @@ -19,7 +19,6 @@ from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.cli.log import logger -from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool router = APIRouter() @@ -38,7 +37,7 @@ async def websocket_endpoint(websocket: WebSocket): #2. if we accept the websocket headers, we will get the online asr engine instance engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] + asr_model = engine_pool['asr'] #3. 
each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio # and each connection has its own connection instance to process the request @@ -70,7 +69,8 @@ async def websocket_endpoint(websocket: WebSocket): resp = {"status": "ok", "signal": "server_ready"} # do something at begining here # create the instance to process the audio - connection_handler = PaddleASRConnectionHanddler(asr_engine) + #connection_handler = PaddleASRConnectionHanddler(asr_model) + connection_handler = asr_model.new_handler() await websocket.send_json(resp) elif message['signal'] == 'end': # reset single engine for an new connection @@ -92,6 +92,7 @@ async def websocket_endpoint(websocket: WebSocket): else: resp = {"status": "ok", "message": "no valid json data"} await websocket.send_json(resp) + elif "bytes" in message: # bytes for the pcm data message = message["bytes"] @@ -100,11 +101,34 @@ async def websocket_endpoint(websocket: WebSocket): # and decode for the result in this package data connection_handler.extract_feat(message) connection_handler.decode(is_finished=False) + + if connection_handler.endpoint_state: + logger.info("endpoint: detected and rescoring.") + connection_handler.rescoring() + word_time_stamp = connection_handler.get_word_time_stamp() + asr_results = connection_handler.get_result() - # return the current period result - # if the engine create the vad instance, this connection will have many period results + if connection_handler.endpoint_state: + if connection_handler.continuous_decoding: + logger.info("endpoint: continue decoding") + connection_handler.reset_continuous_decoding() + else: + logger.info("endpoint: exit decoding") + # ending by endpoint + resp = { + "status": "ok", + "signal": "finished", + 'result': asr_results, + 'times': word_time_stamp + } + await websocket.send_json(resp) + break + + # return the current partial result + # if the engine create the vad instance, this connection will have many partial results resp 
= {'result': asr_results} await websocket.send_json(resp) + except WebSocketDisconnect as e: logger.error(e) diff --git a/paddlespeech/server/ws/tts_api.py b/paddlespeech/server/ws/tts_api.py index a3a4c4d4b..3d8b222ea 100644 --- a/paddlespeech/server/ws/tts_api.py +++ b/paddlespeech/server/ws/tts_api.py @@ -40,6 +40,16 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() tts_engine = engine_pool['tts'] + connection_handler = None + + if tts_engine.engine_type == "online": + from paddlespeech.server.engine.tts.online.python.tts_engine import PaddleTTSConnectionHandler + elif tts_engine.engine_type == "online-onnx": + from paddlespeech.server.engine.tts.online.onnx.tts_engine import PaddleTTSConnectionHandler + else: + logger.error("Online tts engine only support online or online-onnx.") + sys.exit(-1) + try: while True: # careful here, changed the source code from starlette.websockets @@ -57,10 +67,13 @@ async def websocket_endpoint(websocket: WebSocket): "signal": "server ready", "session": session } + + connection_handler = PaddleTTSConnectionHandler(tts_engine) await websocket.send_json(resp) # end request elif message['signal'] == 'end': + connection_handler = None resp = { "status": 0, "signal": "connection will be closed", @@ -75,10 +88,11 @@ async def websocket_endpoint(websocket: WebSocket): # speech synthesis request elif 'text' in message: text_bese64 = message["text"] - sentence = tts_engine.preprocess(text_bese64=text_bese64) + sentence = connection_handler.preprocess( + text_bese64=text_bese64) # run - wav_generator = tts_engine.run(sentence) + wav_generator = connection_handler.run(sentence) while True: try: diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 4e3ad3c12..0b278abaf 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -293,3 +293,45 @@ def transformer_single_spk_batch_fn(examples): "speech_lengths": 
speech_lengths, } return batch + + +def vits_single_spk_batch_fn(examples): + """ + Returns: + Dict[str, Any]: + - text (Tensor): Text index tensor (B, T_text). + - text_lengths (Tensor): Text length tensor (B,). + - feats (Tensor): Feature tensor (B, T_feats, aux_channels). + - feats_lengths (Tensor): Feature length tensor (B,). + - speech (Tensor): Speech waveform tensor (B, T_wav). + + """ + # fields = ["text", "text_lengths", "feats", "feats_lengths", "speech"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + feats = [np.array(item["feats"], dtype=np.float32) for item in examples] + speech = [np.array(item["wave"], dtype=np.float32) for item in examples] + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + feats_lengths = [ + np.array(item["feats_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + feats = batch_sequences(feats) + speech = batch_sequences(speech) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + feats = paddle.to_tensor(feats) + text_lengths = paddle.to_tensor(text_lengths) + feats_lengths = paddle.to_tensor(feats_lengths) + + batch = { + "text": text, + "text_lengths": text_lengths, + "feats": feats, + "feats_lengths": feats_lengths, + "speech": speech + } + return batch diff --git a/paddlespeech/t2s/datasets/batch.py b/paddlespeech/t2s/datasets/batch.py index 9d83bbe09..4f21d4470 100644 --- a/paddlespeech/t2s/datasets/batch.py +++ b/paddlespeech/t2s/datasets/batch.py @@ -167,7 +167,6 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32): def batch_sequences(sequences, axis=0, pad_value=0): - # import pdb; pdb.set_trace() seq = sequences[0] ndim = seq.ndim if axis < 0: diff --git a/paddlespeech/t2s/datasets/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py index b4bea0bd0..21458f152 100644 --- a/paddlespeech/t2s/datasets/get_feats.py +++ b/paddlespeech/t2s/datasets/get_feats.py @@ -20,15 
+20,14 @@ from scipy.interpolate import interp1d class LogMelFBank(): def __init__(self, - sr=24000, - n_fft=2048, - hop_length=300, - win_length=None, - window="hann", - n_mels=80, - fmin=80, - fmax=7600, - eps=1e-10): + sr: int=24000, + n_fft: int=2048, + hop_length: int=300, + win_length: int=None, + window: str="hann", + n_mels: int=80, + fmin: int=80, + fmax: int=7600): self.sr = sr # stft self.n_fft = n_fft @@ -54,7 +53,7 @@ class LogMelFBank(): fmax=self.fmax) return mel_filter - def _stft(self, wav): + def _stft(self, wav: np.ndarray): D = librosa.core.stft( wav, n_fft=self.n_fft, @@ -65,11 +64,11 @@ class LogMelFBank(): pad_mode=self.pad_mode) return D - def _spectrogram(self, wav): + def _spectrogram(self, wav: np.ndarray): D = self._stft(wav) return np.abs(D) - def _mel_spectrogram(self, wav): + def _mel_spectrogram(self, wav: np.ndarray): S = self._spectrogram(wav) mel = np.dot(self.mel_filter, S) return mel @@ -90,14 +89,18 @@ class LogMelFBank(): class Pitch(): - def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600): + def __init__(self, + sr: int=24000, + hop_length: int=300, + f0min: int=80, + f0max: int=7600): self.sr = sr self.hop_length = hop_length self.f0min = f0min self.f0max = f0max - def _convert_to_continuous_f0(self, f0: np.array) -> np.array: + def _convert_to_continuous_f0(self, f0: np.ndarray) -> np.ndarray: if (f0 == 0).all(): print("All frames seems to be unvoiced.") return f0 @@ -120,9 +123,9 @@ class Pitch(): return f0 def _calculate_f0(self, - input: np.array, - use_continuous_f0=True, - use_log_f0=True) -> np.array: + input: np.ndarray, + use_continuous_f0: bool=True, + use_log_f0: bool=True) -> np.ndarray: input = input.astype(np.float) frame_period = 1000 * self.hop_length / self.sr f0, timeaxis = pyworld.dio( @@ -139,7 +142,8 @@ class Pitch(): f0[nonzero_idxs] = np.log(f0[nonzero_idxs]) return f0.reshape(-1) - def _average_by_duration(self, input: np.array, d: np.array) -> np.array: + def _average_by_duration(self, 
input: np.ndarray, + d: np.ndarray) -> np.ndarray: d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant') arr_list = [] for start, end in zip(d_cumsum[:-1], d_cumsum[1:]): @@ -154,11 +158,11 @@ class Pitch(): return arr_list def get_pitch(self, - wav, - use_continuous_f0=True, - use_log_f0=True, - use_token_averaged_f0=True, - duration=None): + wav: np.ndarray, + use_continuous_f0: bool=True, + use_log_f0: bool=True, + use_token_averaged_f0: bool=True, + duration: np.ndarray=None): f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0) if use_token_averaged_f0 and duration is not None: f0 = self._average_by_duration(f0, duration) @@ -167,15 +171,13 @@ class Pitch(): class Energy(): def __init__(self, - sr=24000, - n_fft=2048, - hop_length=300, - win_length=None, - window="hann", - center=True, - pad_mode="reflect"): + n_fft: int=2048, + hop_length: int=300, + win_length: int=None, + window: str="hann", + center: bool=True, + pad_mode: str="reflect"): - self.sr = sr self.n_fft = n_fft self.win_length = win_length self.hop_length = hop_length @@ -183,7 +185,7 @@ class Energy(): self.center = center self.pad_mode = pad_mode - def _stft(self, wav): + def _stft(self, wav: np.ndarray): D = librosa.core.stft( wav, n_fft=self.n_fft, @@ -194,7 +196,7 @@ class Energy(): pad_mode=self.pad_mode) return D - def _calculate_energy(self, input): + def _calculate_energy(self, input: np.ndarray): input = input.astype(np.float32) input_stft = self._stft(input) input_power = np.abs(input_stft)**2 @@ -203,7 +205,8 @@ class Energy(): np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float('inf'))) return energy - def _average_by_duration(self, input: np.array, d: np.array) -> np.array: + def _average_by_duration(self, input: np.ndarray, + d: np.ndarray) -> np.ndarray: d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant') arr_list = [] for start, end in zip(d_cumsum[:-1], d_cumsum[1:]): @@ -214,8 +217,49 @@ class Energy(): arr_list = np.expand_dims(np.array(arr_list), 0).T return arr_list 
- def get_energy(self, wav, use_token_averaged_energy=True, duration=None): + def get_energy(self, + wav: np.ndarray, + use_token_averaged_energy: bool=True, + duration: np.ndarray=None): energy = self._calculate_energy(wav) if use_token_averaged_energy and duration is not None: energy = self._average_by_duration(energy, duration) return energy + + +class LinearSpectrogram(): + def __init__( + self, + n_fft: int=1024, + win_length: int=None, + hop_length: int=256, + window: str="hann", + center: bool=True, ): + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.center = center + self.n_fft = n_fft + self.pad_mode = "reflect" + + def _stft(self, wav: np.ndarray): + D = librosa.core.stft( + wav, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode) + return D + + def _spectrogram(self, wav: np.ndarray): + D = self._stft(wav) + return np.abs(D) + + def get_linear_spectrogram(self, wav: np.ndarray): + linear_spectrogram = self._spectrogram(wav) + linear_spectrogram = np.clip( + linear_spectrogram, a_min=1e-10, a_max=float("inf")) + return linear_spectrogram.T diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py index 5fc51365a..eac75f982 100644 --- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py +++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py @@ -147,10 +147,17 @@ def process_sentences(config, spk_emb_dir: Path=None): if nprocs == 1: results = [] - for fp in fps: - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, pitch_extractor, - energy_extractor, cut_sil, spk_emb_dir) + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + 
cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir) if record: results.append(record) else: @@ -325,7 +332,6 @@ def main(): f0min=config.f0min, f0max=config.f0max) energy_extractor = Energy( - sr=config.fs, n_fft=config.n_fft, hop_length=config.n_shift, win_length=config.win_length, @@ -334,36 +340,36 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, - pitch_extractor, - energy_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, - pitch_extractor, - energy_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, - pitch_extractor, - energy_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) diff --git a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py index c70821e78..4c733dc9b 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py +++ b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py @@ -243,8 +243,7 @@ def main(): # parse args and config and redirect to train_sp parser = argparse.ArgumentParser(description="Train a HiFiGAN model.") - parser.add_argument( - 
"--config", type=str, help="config file to overwrite default config.") + parser.add_argument("--config", type=str, help="HiFiGAN config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py index 27ffded63..3b3ebb478 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py +++ b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py @@ -233,7 +233,7 @@ def main(): parser = argparse.ArgumentParser( description="Train a Multi-Band MelGAN model.") parser.add_argument( - "--config", type=str, help="config file to overwrite default config.") + "--config", type=str, help="Multi-Band MelGAN config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py index 92de7a2c4..b26407028 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py +++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py @@ -208,7 +208,7 @@ def main(): parser = argparse.ArgumentParser( description="Train a ParallelWaveGAN model.") parser.add_argument( - "--config", type=str, help="config file to overwrite default config.") + "--config", type=str, help="ParallelWaveGAN config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py 
b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py index 8adab0fea..546367964 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py +++ b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py @@ -88,15 +88,17 @@ def process_sentence(config: Dict[str, Any], y, (0, num_frames * config.n_shift - y.size), mode="reflect") else: y = y[:num_frames * config.n_shift] - num_sample = y.shape[0] + num_samples = y.shape[0] mel_path = output_dir / (utt_id + "_feats.npy") wav_path = output_dir / (utt_id + "_wave.npy") - np.save(wav_path, y) # (num_samples, ) - np.save(mel_path, logmel) # (num_frames, n_mels) + # (num_samples, ) + np.save(wav_path, y) + # (num_frames, n_mels) + np.save(mel_path, logmel) record = { "utt_id": utt_id, - "num_samples": num_sample, + "num_samples": num_samples, "num_frames": num_frames, "feats": str(mel_path), "wave": str(wav_path), @@ -111,11 +113,17 @@ def process_sentences(config, mel_extractor=None, nprocs: int=1, cut_sil: bool=True): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil) if record: results.append(record) else: @@ -150,7 +158,7 @@ def main(): "--dataset", default="baker", type=str, - help="name of dataset, should in {baker, ljspeech, vctk} now") + help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now") parser.add_argument( "--rootdir", default=None, type=str, help="directory to dataset.") parser.add_argument( @@ -264,28 +272,28 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) if dev_wav_files: 
process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) diff --git a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py index be3ba7425..a87cc7a18 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py +++ b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py @@ -224,8 +224,7 @@ def main(): # parse args and config and redirect to train_sp parser = argparse.ArgumentParser(description="Train a Style MelGAN model.") - parser.add_argument( - "--config", type=str, help="config file to overwrite default config.") + parser.add_argument("--config", type=str, help="Style MelGAN config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index b680f19a9..624defc6a 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -90,7 +90,7 @@ def parse_args(): default=False, help="whether use streaming acoustic model") parser.add_argument( - "--chunk_size", type=int, default=42, help="chunk size of am streaming") + "--block_size", type=int, default=42, help="block size of am streaming") parser.add_argument( "--pad_size", type=int, default=12, help="pad size of am streaming") @@ -169,7 +169,7 @@ def main(): N = 0 T = 0 - chunk_size = args.chunk_size + block_size 
= args.block_size pad_size = args.pad_size get_tone_ids = False for utt_id, sentence in sentences: @@ -189,7 +189,7 @@ def main(): am_encoder_infer_predictor, input=phones) if args.am_streaming: - hss = get_chunks(orig_hs, chunk_size, pad_size) + hss = get_chunks(orig_hs, block_size, pad_size) chunk_num = len(hss) mel_list = [] for i, hs in enumerate(hss): @@ -211,7 +211,7 @@ def main(): sub_mel = sub_mel[pad_size:] else: # 倒数几块的右侧也可能没有 pad 够 - sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel = sub_mel[pad_size:(block_size + pad_size) - sub_mel.shape[0]] mel_list.append(sub_mel) mel = np.concatenate(mel_list, axis=0) diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 5d2c66bc9..d5241f1c6 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -97,7 +97,7 @@ def ort_predict(args): T = 0 merge_sentences = True get_tone_ids = False - chunk_size = args.chunk_size + block_size = args.block_size pad_size = args.pad_size for utt_id, sentence in sentences: @@ -115,7 +115,7 @@ def ort_predict(args): orig_hs = am_encoder_infer_sess.run( None, input_feed={'text': phone_ids}) if args.am_streaming: - hss = get_chunks(orig_hs[0], chunk_size, pad_size) + hss = get_chunks(orig_hs[0], block_size, pad_size) chunk_num = len(hss) mel_list = [] for i, hs in enumerate(hss): @@ -139,7 +139,7 @@ def ort_predict(args): sub_mel = sub_mel[pad_size:] else: # 倒数几块的右侧也可能没有 pad 够 - sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel = sub_mel[pad_size:(block_size + pad_size) - sub_mel.shape[0]] mel_list.append(sub_mel) mel = np.concatenate(mel_list, axis=0) @@ -236,7 +236,7 @@ def parse_args(): default=False, help="whether use streaming acoustic model") parser.add_argument( - "--chunk_size", type=int, default=42, help="chunk size of am streaming") + "--block_size", type=int, default=42, help="block size of am streaming") parser.add_argument( 
"--pad_size", type=int, default=12, help="pad size of am streaming") diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index 6c6b443fb..aa7608d6b 100644 --- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py +++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py @@ -126,11 +126,17 @@ def process_sentences(config, nprocs: int=1, cut_sil: bool=True, use_relative_path: bool=False): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil) if record: results.append(record) else: @@ -268,30 +274,30 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) diff --git a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py index 252ac9326..644ec250d 100644 --- 
a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py @@ -176,7 +176,10 @@ def main(): parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.") parser.add_argument( - "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") + "--nxpu", + type=int, + default=0, + help="if nxpu == 0 and ngpu == 0, use cpu.") args, _ = parser.parse_known_args() diff --git a/paddlespeech/t2s/exps/speedyspeech/train.py b/paddlespeech/t2s/exps/speedyspeech/train.py index d4cfe3488..7b422e64f 100644 --- a/paddlespeech/t2s/exps/speedyspeech/train.py +++ b/paddlespeech/t2s/exps/speedyspeech/train.py @@ -188,7 +188,10 @@ def main(): parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument( - "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") + "--nxpu", + type=int, + default=0, + help="if nxpu == 0 and ngpu == 0, use cpu.") parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu") diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index ce0aee05e..6b9f41a6b 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -27,11 +27,11 @@ from paddle import jit from paddle.static import InputSpec from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.utils.dynamic_import import dynamic_import model_alias = { # acoustic model @@ -75,13 +75,13 @@ def denorm(data, mean, std): return data * std + mean -def get_chunks(data, chunk_size: int, pad_size: int): +def get_chunks(data, block_size: int, pad_size: 
int): data_len = data.shape[1] chunks = [] - n = math.ceil(data_len / chunk_size) + n = math.ceil(data_len / block_size) for i in range(n): - start = max(0, i * chunk_size - pad_size) - end = min((i + 1) * chunk_size + pad_size, data_len) + start = max(0, i * block_size - pad_size) + end = min((i + 1) * block_size + pad_size, data_len) chunks.append(data[:, start:end, :]) return chunks diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index 0855a6a2a..9ddab726e 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -107,8 +107,8 @@ def evaluate(args): if args.voice_cloning and "spk_emb" in datum: spk_emb = paddle.to_tensor(np.load(datum["spk_emb"])) mel = am_inference(phone_ids, spk_emb=spk_emb) - # vocoder - wav = voc_inference(mel) + # vocoder + wav = voc_inference(mel) wav = wav.numpy() N += wav.size @@ -125,7 +125,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -140,10 +140,7 @@ def parse_args(): ], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model. Use deault config when it is None.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -179,10 +176,7 @@ def parse_args(): ], help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc. 
Use deault config when it is None.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 2f14ef564..28657eb27 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -159,7 +159,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -174,10 +174,7 @@ def parse_args(): ], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model. Use deault config when it is None.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -220,10 +217,7 @@ def parse_args(): ], help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc. 
Use deault config when it is None.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 3659cb490..d8b23f1ad 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -24,7 +24,6 @@ from paddle.static import InputSpec from timer import timer from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.exps.syn_utils import denorm from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend @@ -33,6 +32,7 @@ from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import model_alias from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.utils import str2bool +from paddlespeech.utils.dynamic_import import dynamic_import def evaluate(args): @@ -133,7 +133,7 @@ def evaluate(args): N = 0 T = 0 - chunk_size = args.chunk_size + block_size = args.block_size pad_size = args.pad_size for utt_id, sentence in sentences: @@ -153,7 +153,7 @@ def evaluate(args): # acoustic model orig_hs = am_encoder_infer(phone_ids) if args.am_streaming: - hss = get_chunks(orig_hs, chunk_size, pad_size) + hss = get_chunks(orig_hs, block_size, pad_size) chunk_num = len(hss) mel_list = [] for i, hs in enumerate(hss): @@ -171,7 +171,7 @@ def evaluate(args): sub_mel = sub_mel[pad_size:] else: # 倒数几块的右侧也可能没有 pad 够 - sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel = sub_mel[pad_size:(block_size + pad_size) - sub_mel.shape[0]] mel_list.append(sub_mel) mel = paddle.concat(mel_list, axis=0) @@ -201,7 +201,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config 
parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -212,10 +212,7 @@ def parse_args(): choices=['fastspeech2_csmsc'], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model. Use deault config when it is None.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -245,10 +242,7 @@ def parse_args(): ], help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc. Use deault config when it is None.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( @@ -283,7 +277,7 @@ def parse_args(): default=False, help="whether use streaming acoustic model") parser.add_argument( - "--chunk_size", type=int, default=42, help="chunk size of am streaming") + "--block_size", type=int, default=42, help="block size of am streaming") parser.add_argument( "--pad_size", type=int, default=12, help="pad size of am streaming") diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py index 95349d595..6137da7f1 100644 --- a/paddlespeech/t2s/exps/tacotron2/preprocess.py +++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py @@ -125,9 +125,15 @@ def process_sentences(config, spk_emb_dir: Path=None): if nprocs == 1: results = [] - for fp in fps: - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil, spk_emb_dir) + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir) if record: results.append(record) else: @@ -299,30 +305,30 @@ def main(): # 
process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess.py b/paddlespeech/t2s/exps/transformer_tts/preprocess.py index 9aa87e91a..28ca3de6e 100644 --- a/paddlespeech/t2s/exps/transformer_tts/preprocess.py +++ b/paddlespeech/t2s/exps/transformer_tts/preprocess.py @@ -125,11 +125,16 @@ def process_sentences(config, output_dir: Path, mel_extractor=None, nprocs: int=1): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor) if record: results.append(record) else: @@ -247,27 +252,27 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) if dev_wav_files: process_sentences( - config, - 
dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) diff --git a/paddlespeech/t2s/exps/transformer_tts/train.py b/paddlespeech/t2s/exps/transformer_tts/train.py index 45ecb269b..da48b6b99 100644 --- a/paddlespeech/t2s/exps/transformer_tts/train.py +++ b/paddlespeech/t2s/exps/transformer_tts/train.py @@ -160,7 +160,7 @@ def main(): parser = argparse.ArgumentParser(description="Train a TransformerTTS " "model with LJSpeech TTS dataset.") parser.add_argument( - "--config", type=str, help="config file to overwrite default config.") + "--config", type=str, help="TransformerTTS config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/exps/vits/normalize.py b/paddlespeech/t2s/exps/vits/normalize.py new file mode 100644 index 000000000..6fc8adb06 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/normalize.py @@ -0,0 +1,165 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Normalize feature files and dump them.""" +import argparse +import logging +from operator import itemgetter +from pathlib import Path + +import jsonlines +import numpy as np +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from paddlespeech.t2s.datasets.data_table import DataTable + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser( + description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)." + ) + parser.add_argument( + "--metadata", + type=str, + required=True, + help="directory including feature files to be normalized. " + "you need to specify either *-scp or rootdir.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump normalized feature files.") + parser.add_argument( + "--feats-stats", + type=str, + required=True, + help="speech statistics file.") + parser.add_argument( + "--skip-wav-copy", + default=False, + action="store_true", + help="whether to skip the copy of wav files.") + + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. 
(default=1)") + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + logging.warning('Skip DEBUG/INFO messages') + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, + converters={ + "feats": np.load, + "wave": None if args.skip_wav_copy else np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # restore scaler + feats_scaler = StandardScaler() + feats_scaler.mean_ = np.load(args.feats_stats)[0] + feats_scaler.scale_ = np.load(args.feats_stats)[1] + feats_scaler.n_features_in_ = feats_scaler.mean_.shape[0] + + vocab_phones = {} + with open(args.phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + feats = item['feats'] + wave = item['wave'] + + # normalize + feats = feats_scaler.transform(feats) + feats_path = dumpdir / f"{utt_id}_feats.npy" + np.save(feats_path, feats.astype(np.float32), allow_pickle=False) + + if not args.skip_wav_copy: + wav_path = dumpdir / f"{utt_id}_wave.npy" + np.save(wav_path, wave.astype(np.float32), allow_pickle=False) + 
else: + wav_path = wave + + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + + record = { + "utt_id": item['utt_id'], + "text": phone_ids, + "text_lengths": item['text_lengths'], + 'feats': str(feats_path), + "feats_lengths": item['feats_lengths'], + "wave": str(wav_path), + "spk_id": spk_id, + } + + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/preprocess.py b/paddlespeech/t2s/exps/vits/preprocess.py new file mode 100644 index 000000000..6aa139fb5 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/preprocess.py @@ -0,0 +1,348 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import os +from concurrent.futures import ThreadPoolExecutor +from operator import itemgetter +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List + +import jsonlines +import librosa +import numpy as np +import tqdm +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram +from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length +from paddlespeech.t2s.datasets.preprocess_utils import get_input_token +from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur +from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map +from paddlespeech.t2s.datasets.preprocess_utils import merge_silence +from paddlespeech.t2s.utils import str2bool + + +def process_sentence(config: Dict[str, Any], + fp: Path, + sentences: Dict, + output_dir: Path, + spec_extractor=None, + cut_sil: bool=True, + spk_emb_dir: Path=None): + utt_id = fp.stem + # for vctk + if utt_id.endswith("_mic2"): + utt_id = utt_id[:-5] + record = None + if utt_id in sentences: + # reading, resampling may occur + wav, _ = librosa.load(str(fp), sr=config.fs) + if len(wav.shape) != 1: + return record + max_value = np.abs(wav).max() + if max_value > 1.0: + wav = wav / max_value + assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio." + assert np.abs(wav).max( + ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM." 
+ phones = sentences[utt_id][0] + durations = sentences[utt_id][1] + speaker = sentences[utt_id][2] + d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant') + # little imprecise than use *.TextGrid directly + times = librosa.frames_to_time( + d_cumsum, sr=config.fs, hop_length=config.n_shift) + if cut_sil: + start = 0 + end = d_cumsum[-1] + if phones[0] == "sil" and len(durations) > 1: + start = times[1] + durations = durations[1:] + phones = phones[1:] + if phones[-1] == 'sil' and len(durations) > 1: + end = times[-2] + durations = durations[:-1] + phones = phones[:-1] + sentences[utt_id][0] = phones + sentences[utt_id][1] = durations + start, end = librosa.time_to_samples([start, end], sr=config.fs) + wav = wav[start:end] + # extract mel feats + spec = spec_extractor.get_linear_spectrogram(wav) + # change duration according to mel_length + compare_duration_and_mel_length(sentences, utt_id, spec) + # utt_id may be popped in compare_duration_and_mel_length + if utt_id not in sentences: + return None + phones = sentences[utt_id][0] + durations = sentences[utt_id][1] + num_frames = spec.shape[0] + assert sum(durations) == num_frames + + if wav.size < num_frames * config.n_shift: + wav = np.pad( + wav, (0, num_frames * config.n_shift - wav.size), + mode="reflect") + else: + wav = wav[:num_frames * config.n_shift] + num_samples = wav.shape[0] + + spec_path = output_dir / (utt_id + "_feats.npy") + wav_path = output_dir / (utt_id + "_wave.npy") + # (num_samples, ) + np.save(wav_path, wav) + # (num_frames, aux_channels) + np.save(spec_path, spec) + + record = { + "utt_id": utt_id, + "phones": phones, + "text_lengths": len(phones), + "feats": str(spec_path), + "feats_lengths": num_frames, + "wave": str(wav_path), + "speaker": speaker + } + if spk_emb_dir: + if speaker in os.listdir(spk_emb_dir): + embed_name = utt_id + ".npy" + embed_path = spk_emb_dir / speaker / embed_name + if embed_path.is_file(): + record["spk_emb"] = str(embed_path) + else: + return 
None + return record + + +def process_sentences(config, + fps: List[Path], + sentences: Dict, + output_dir: Path, + spec_extractor=None, + nprocs: int=1, + cut_sil: bool=True, + spk_emb_dir: Path=None): + if nprocs == 1: + results = [] + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + spec_extractor=spec_extractor, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir) + if record: + results.append(record) + else: + with ThreadPoolExecutor(nprocs) as pool: + futures = [] + with tqdm.tqdm(total=len(fps)) as progress: + for fp in fps: + future = pool.submit(process_sentence, config, fp, + sentences, output_dir, spec_extractor, + cut_sil, spk_emb_dir) + future.add_done_callback(lambda p: progress.update()) + futures.append(future) + + results = [] + for ft in futures: + record = ft.result() + if record: + results.append(record) + + results.sort(key=itemgetter("utt_id")) + with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer: + for item in results: + writer.write(item) + print("Done") + + +def main(): + # parse config and args + parser = argparse.ArgumentParser( + description="Preprocess audio and then extract features.") + + parser.add_argument( + "--dataset", + default="baker", + type=str, + help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now") + + parser.add_argument( + "--rootdir", default=None, type=str, help="directory to dataset.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump feature files.") + parser.add_argument( + "--dur-file", default=None, type=str, help="path to durations.txt.") + + parser.add_argument("--config", type=str, help="fastspeech2 config file.") + + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. 
(default=1)") + parser.add_argument( + "--num-cpu", type=int, default=1, help="number of process.") + + parser.add_argument( + "--cut-sil", + type=str2bool, + default=True, + help="whether cut sil in the edge of audio") + + parser.add_argument( + "--spk_emb_dir", + default=None, + type=str, + help="directory to speaker embedding files.") + args = parser.parse_args() + + rootdir = Path(args.rootdir).expanduser() + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + dur_file = Path(args.dur_file).expanduser() + + if args.spk_emb_dir: + spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve() + else: + spk_emb_dir = None + + assert rootdir.is_dir() + assert dur_file.is_file() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + if args.verbose > 1: + print(vars(args)) + print(config) + + sentences, speaker_set = get_phn_dur(dur_file) + + merge_silence(sentences) + phone_id_map_path = dumpdir / "phone_id_map.txt" + speaker_id_map_path = dumpdir / "speaker_id_map.txt" + get_input_token(sentences, phone_id_map_path, args.dataset) + get_spk_id_map(speaker_set, speaker_id_map_path) + + if args.dataset == "baker": + wav_files = sorted(list((rootdir / "Wave").rglob("*.wav"))) + # split data into 3 sections + num_train = 9800 + num_dev = 100 + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "aishell3": + sub_num_dev = 5 + wav_dir = rootdir / "train" / "wav" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*.wav"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += 
wav_files + + elif args.dataset == "ljspeech": + wav_files = sorted(list((rootdir / "wavs").rglob("*.wav"))) + # split data into 3 sections + num_train = 12900 + num_dev = 100 + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "vctk": + sub_num_dev = 5 + wav_dir = rootdir / "wav48_silence_trimmed" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + else: + print("dataset should in {baker, aishell3, ljspeech, vctk} now!") + + train_dump_dir = dumpdir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dumpdir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dumpdir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # Extractor + + spec_extractor = LinearSpectrogram( + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window) + + # process for the 3 sections + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + spec_extractor=spec_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + spec_extractor=spec_extractor, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + spec_extractor=spec_extractor, + 
nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/synthesize.py b/paddlespeech/t2s/exps/vits/synthesize.py new file mode 100644 index 000000000..074b890f9 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/synthesize.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import jsonlines +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.vits import VITS + + +def evaluate(args): + + # construct dataset for evaluation + with jsonlines.open(args.test_metadata, 'r') as reader: + test_metadata = list(reader) + # Init body. 
+ with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + fields = ["utt_id", "text"] + + test_dataset = DataTable(data=test_metadata, fields=fields) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + + vits = VITS(idim=vocab_size, odim=odim, **config["model"]) + vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) + vits.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + N = 0 + T = 0 + + for datum in test_dataset: + utt_id = datum["utt_id"] + phone_ids = paddle.to_tensor(datum["text"]) + with timer() as t: + with paddle.no_grad(): + out = vits.inference(text=phone_ids) + wav = out["wav"] + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
+ ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with VITS") + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of VITS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of VITS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + # other + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/synthesize_e2e.py b/paddlespeech/t2s/exps/vits/synthesize_e2e.py new file mode 100644 index 000000000..c82e5c039 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/synthesize_e2e.py @@ -0,0 +1,146 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +from pathlib import Path + +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.models.vits import VITS + + +def evaluate(args): + + # Init body. + with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + sentences = get_sentences(text_file=args.text, lang=args.lang) + + # frontend + frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + + vits = VITS(idim=vocab_size, odim=odim, **config["model"]) + vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) + vits.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = False + + N = 0 + T = 0 + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + with paddle.no_grad(): + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + out = vits.inference(text=part_phone_ids) + wav = out["wav"] + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + wav = wav_all.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.shape}, time: 
{t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with VITS") + + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of VITS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of VITS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py new file mode 100644 index 000000000..dbda8b717 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/train.py @@ -0,0 +1,260 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.vits import VITS +from paddlespeech.t2s.models.vits import VITSEvaluator +from paddlespeech.t2s.models.vits import VITSUpdater +from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss +from paddlespeech.t2s.modules.losses import FeatureMatchLoss +from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss +from paddlespeech.t2s.modules.losses import KLDivergenceLoss +from paddlespeech.t2s.modules.losses import MelSpectrogramLoss +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.optimizer import scheduler_classes +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + 
paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + fields = ["text", "text_lengths", "feats", "feats_lengths", "wave"] + + converters = { + "wave": np.load, + "feats": np.load, + } + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + train_batch_fn = vits_single_spk_batch_fn + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + model = VITS(idim=vocab_size, odim=odim, **config["model"]) + gen_parameters = model.generator.parameters() + dis_parameters = model.discriminator.parameters() + if 
world_size > 1: + model = DataParallel(model) + gen_parameters = model._layers.generator.parameters() + dis_parameters = model._layers.discriminator.parameters() + + print("model done!") + + # loss + criterion_mel = MelSpectrogramLoss( + **config["mel_loss_params"], ) + criterion_feat_match = FeatureMatchLoss( + **config["feat_match_loss_params"], ) + criterion_gen_adv = GeneratorAdversarialLoss( + **config["generator_adv_loss_params"], ) + criterion_dis_adv = DiscriminatorAdversarialLoss( + **config["discriminator_adv_loss_params"], ) + criterion_kl = KLDivergenceLoss() + + print("criterions done!") + + lr_schedule_g = scheduler_classes[config["generator_scheduler"]]( + **config["generator_scheduler_params"]) + optimizer_g = Adam( + learning_rate=lr_schedule_g, + parameters=gen_parameters, + **config["generator_optimizer_params"]) + + lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]]( + **config["discriminator_scheduler_params"]) + optimizer_d = Adam( + learning_rate=lr_schedule_d, + parameters=dis_parameters, + **config["discriminator_optimizer_params"]) + + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = VITSUpdater( + model=model, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "mel": criterion_mel, + "feat_match": criterion_feat_match, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "kl": criterion_kl, + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + lambda_adv=config.lambda_adv, + lambda_mel=config.lambda_mel, + lambda_kl=config.lambda_kl, + lambda_feat_match=config.lambda_feat_match, + lambda_dur=config.lambda_dur, + generator_first=config.generator_first, + 
output_dir=output_dir) + + evaluator = VITSEvaluator( + model=model, + criterions={ + "mel": criterion_mel, + "feat_match": criterion_feat_match, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "kl": criterion_kl, + }, + dataloader=dev_dataloader, + lambda_adv=config.lambda_adv, + lambda_mel=config.lambda_mel, + lambda_kl=config.lambda_kl, + lambda_feat_match=config.lambda_feat_match, + lambda_dur=config.lambda_dur, + generator_first=config.generator_first, + output_dir=output_dir) + + trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) + + if dist.get_rank() == 0: + trainer.extend(evaluator, trigger=(1, "epoch")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) + trainer.extend( + Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) + + print("Trainer Done!") + trainer.run() + + +def main(): + # parse args and config and redirect to train_sp + + parser = argparse.ArgumentParser(description="Train a VITS model.") + parser.add_argument("--config", type=str, help="VITS config file") + parser.add_argument("--train-metadata", type=str, help="training data.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + + args = parser.parse_args() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + print( + f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + # dispatch + if args.ngpu > 1: + dist.spawn(train_sp, (args, config), nprocs=args.ngpu) + else: + train_sp(args, config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/voice_cloning.py 
b/paddlespeech/t2s/exps/voice_cloning.py index 2742cd068..b51a4d7bc 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -122,7 +122,7 @@ def voice_cloning(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser(description="") parser.add_argument( '--am', @@ -131,10 +131,7 @@ def parse_args(): choices=['fastspeech2_aishell3', 'tacotron2_aishell3'], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model. Use deault config when it is None.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -160,10 +157,7 @@ def parse_args(): help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc. Use deault config when it is None.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/wavernn/train.py b/paddlespeech/t2s/exps/wavernn/train.py index 8661d311d..cf24ea268 100644 --- a/paddlespeech/t2s/exps/wavernn/train.py +++ b/paddlespeech/t2s/exps/wavernn/train.py @@ -180,8 +180,7 @@ def main(): # parse args and config and redirect to train_sp parser = argparse.ArgumentParser(description="Train a WaveRNN model.") - parser.add_argument( - "--config", type=str, help="config file to overwrite default config.") + parser.add_argument("--config", type=str, help="WaveRNN config file.") parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") diff --git a/paddlespeech/t2s/frontend/zh_frontend.py 
b/paddlespeech/t2s/frontend/zh_frontend.py index bb8ed5b49..129aa944e 100644 --- a/paddlespeech/t2s/frontend/zh_frontend.py +++ b/paddlespeech/t2s/frontend/zh_frontend.py @@ -195,7 +195,7 @@ class Frontend(): new_initials.append(initials[i]) return new_initials, new_finals - def _p2id(self, phonemes: List[str]) -> np.array: + def _p2id(self, phonemes: List[str]) -> np.ndarray: # replace unk phone with sp phonemes = [ phn if phn in self.vocab_phones else "sp" for phn in phonemes @@ -203,7 +203,7 @@ class Frontend(): phone_ids = [self.vocab_phones[item] for item in phonemes] return np.array(phone_ids, np.int64) - def _t2id(self, tones: List[str]) -> np.array: + def _t2id(self, tones: List[str]) -> np.ndarray: # replace unk phone with sp tones = [tone if tone in self.vocab_tones else "0" for tone in tones] tone_ids = [self.vocab_tones[item] for item in tones] diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py index 41be7c1db..0b6f29119 100644 --- a/paddlespeech/t2s/models/__init__.py +++ b/paddlespeech/t2s/models/__init__.py @@ -18,5 +18,6 @@ from .parallel_wavegan import * from .speedyspeech import * from .tacotron2 import * from .transformer_tts import * +from .vits import * from .waveflow import * from .wavernn import * diff --git a/paddlespeech/t2s/models/hifigan/hifigan.py b/paddlespeech/t2s/models/hifigan/hifigan.py index ac5ff204f..bea9dd9a3 100644 --- a/paddlespeech/t2s/models/hifigan/hifigan.py +++ b/paddlespeech/t2s/models/hifigan/hifigan.py @@ -16,6 +16,7 @@ import copy from typing import Any from typing import Dict from typing import List +from typing import Optional import paddle import paddle.nn.functional as F @@ -34,6 +35,7 @@ class HiFiGANGenerator(nn.Layer): in_channels: int=80, out_channels: int=1, channels: int=512, + global_channels: int=-1, kernel_size: int=7, upsample_scales: List[int]=(8, 8, 2, 2), upsample_kernel_sizes: List[int]=(16, 16, 4, 4), @@ -51,6 +53,7 @@ class HiFiGANGenerator(nn.Layer): in_channels 
(int): Number of input channels. out_channels (int): Number of output channels. channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. kernel_size (int): Kernel size of initial and final conv layer. upsample_scales (list): List of upsampling scales. upsample_kernel_sizes (list): List of kernel sizes for upsampling layers. @@ -119,6 +122,9 @@ class HiFiGANGenerator(nn.Layer): padding=(kernel_size - 1) // 2, ), nn.Tanh(), ) + if global_channels > 0: + self.global_conv = nn.Conv1D(global_channels, channels, 1) + nn.initializer.set_global_initializer(None) # apply weight norm @@ -128,15 +134,18 @@ class HiFiGANGenerator(nn.Layer): # reset parameters self.reset_parameters() - def forward(self, c): + def forward(self, c, g: Optional[paddle.Tensor]=None): """Calculate forward propagation. Args: c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). Returns: Tensor: Output tensor (B, out_channels, T). """ c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) for i in range(self.num_upsamples): c = self.upsamples[i](c) # initialize @@ -187,16 +196,19 @@ class HiFiGANGenerator(nn.Layer): self.apply(_remove_weight_norm) - def inference(self, c): + def inference(self, c, g: Optional[paddle.Tensor]=None): """Perform inference. Args: c (Tensor): Input tensor (T, in_channels). normalize_before (bool): Whether to perform normalization. + g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). Returns: Tensor: Output tensor (T ** prod(upsample_scales), out_channels). 
""" - c = self.forward(c.transpose([1, 0]).unsqueeze(0)) + if g is not None: + g = g.unsqueeze(0) + c = self.forward(c.transpose([1, 0]).unsqueeze(0), g=g) return c.squeeze(0).transpose([1, 0]) diff --git a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py index 40cfff5a5..c1cd73308 100644 --- a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py +++ b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py @@ -68,8 +68,8 @@ class PWGUpdater(StandardUpdater): self.discriminator_train_start_steps = discriminator_train_start_steps self.lambda_adv = lambda_adv self.lambda_aux = lambda_aux - self.state = UpdaterState(iteration=0, epoch=0) + self.state = UpdaterState(iteration=0, epoch=0) self.train_iterator = iter(self.dataloader) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py index e30a3fe1a..b20fda1f7 100644 --- a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py +++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py @@ -16,7 +16,6 @@ from pathlib import Path import paddle from paddle import distributed as dist -from paddle.fluid.layers import huber_loss from paddle.io import DataLoader from paddle.nn import functional as F from paddle.nn import Layer @@ -78,8 +77,11 @@ class SpeedySpeechUpdater(StandardUpdater): target_durations.astype(predicted_durations.dtype), paddle.to_tensor([1.0])) duration_loss = weighted_mean( - huber_loss( - predicted_durations, paddle.log(target_durations), delta=1.0), + F.smooth_l1_loss( + predicted_durations, + paddle.log(target_durations), + delta=1.0, + reduction='none', ), text_mask, ) # ssim loss @@ -146,8 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator): target_durations.astype(predicted_durations.dtype), paddle.to_tensor([1.0])) 
duration_loss = weighted_mean( - huber_loss( - predicted_durations, paddle.log(target_durations), delta=1.0), + F.smooth_l1_loss( + predicted_durations, + paddle.log(target_durations), + delta=1.0, + reduction='none', ), text_mask, ) # ssim loss diff --git a/paddlespeech/t2s/models/vits/__init__.py b/paddlespeech/t2s/models/vits/__init__.py new file mode 100644 index 000000000..ea43028ae --- /dev/null +++ b/paddlespeech/t2s/models/vits/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .vits import * +from .vits_updater import * diff --git a/paddlespeech/t2s/models/vits/duration_predictor.py b/paddlespeech/t2s/models/vits/duration_predictor.py new file mode 100644 index 000000000..6197d5696 --- /dev/null +++ b/paddlespeech/t2s/models/vits/duration_predictor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Stochastic duration predictor modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +import math +from typing import Optional + +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.t2s.models.vits.flow import ConvFlow +from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv +from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow +from paddlespeech.t2s.models.vits.flow import FlipFlow +from paddlespeech.t2s.models.vits.flow import LogFlow + + +class StochasticDurationPredictor(nn.Layer): + """Stochastic duration predictor module. + This is a module of stochastic duration predictor described in `Conditional + Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2106.06103 + """ + + def __init__( + self, + channels: int=192, + kernel_size: int=3, + dropout_rate: float=0.5, + flows: int=4, + dds_conv_layers: int=3, + global_channels: int=-1, ): + """Initialize StochasticDurationPredictor module. + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + dropout_rate (float): Dropout rate. + flows (int): Number of flows. + dds_conv_layers (int): Number of conv layers in DDS conv. + global_channels (int): Number of global conditioning channels. 
+ """ + super().__init__() + + self.pre = nn.Conv1D(channels, channels, 1) + self.dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, ) + self.proj = nn.Conv1D(channels, channels, 1) + + self.log_flow = LogFlow() + self.flows = nn.LayerList() + self.flows.append(ElementwiseAffineFlow(2)) + for i in range(flows): + self.flows.append( + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, )) + self.flows.append(FlipFlow()) + + self.post_pre = nn.Conv1D(1, channels, 1) + self.post_dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, ) + self.post_proj = nn.Conv1D(channels, channels, 1) + self.post_flows = nn.LayerList() + self.post_flows.append(ElementwiseAffineFlow(2)) + for i in range(flows): + self.post_flows.append( + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, )) + self.post_flows.append(FlipFlow()) + + if global_channels > 0: + self.global_conv = nn.Conv1D(global_channels, channels, 1) + + def forward( + self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + w: Optional[paddle.Tensor]=None, + g: Optional[paddle.Tensor]=None, + inverse: bool=False, + noise_scale: float=1.0, ) -> paddle.Tensor: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, channels, T_text). + x_mask (Tensor): Mask tensor (B, 1, T_text). + w (Optional[Tensor]): Duration tensor (B, 1, T_text). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) + inverse (bool): Whether to inverse the flow. + noise_scale (float): Noise scale value. + Returns: + Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). + If inverse, log-duration tensor (B, 1, T_text). 
+ """ + # stop gradient + # x = x.detach() + x = self.pre(x) + if g is not None: + # stop gradient + x = x + self.global_conv(g.detach()) + x = self.dds(x, x_mask) + x = self.proj(x) * x_mask + + if not inverse: + assert w is not None, "w must be provided." + h_w = self.post_pre(w) + h_w = self.post_dds(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) * + x_mask) + z_q = e_q + logdet_tot_q = 0.0 + for i, flow in enumerate(self.post_flows): + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = paddle.split(z_q, [1, 1], 1) + u = F.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += paddle.sum( + (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2]) + logq = (paddle.sum(-0.5 * + (math.log(2 * math.pi) + + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = paddle.concat([z0, z1], 1) + for flow in self.flows: + z, logdet = flow(z, x_mask, g=x, inverse=inverse) + logdet_tot = logdet_tot + logdet + nll = (paddle.sum(0.5 * (math.log(2 * math.pi) + + (z**2)) * x_mask, [1, 2]) - logdet_tot) + # (B,) + return nll + logq + else: + flows = list(reversed(self.flows)) + # remove a useless vflow + flows = flows[:-2] + [flows[-1]] + z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) * + noise_scale) + for flow in flows: + z = flow(z, x_mask, g=x, inverse=inverse) + z0, z1 = paddle.split(z, 2, axis=1) + logw = z0 + return logw diff --git a/paddlespeech/t2s/models/vits/flow.py b/paddlespeech/t2s/models/vits/flow.py new file mode 100644 index 000000000..3c8f89356 --- /dev/null +++ b/paddlespeech/t2s/models/vits/flow.py @@ -0,0 +1,313 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +import math +from typing import Optional +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn + +from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform + + +class FlipFlow(nn.Layer): + """Flip flow module.""" + + def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + Returns: + Tensor: Flipped tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + """ + x = paddle.flip(x, [1]) + if not inverse: + logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype) + return x, logdet + else: + return x + + +class LogFlow(nn.Layer): + """Log flow module.""" + + def forward(self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + inverse: bool=False, + eps: float=1e-5, + **kwargs + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + eps (float): Epsilon for log. + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
+ """ + if not inverse: + y = paddle.log(paddle.clip(x, min=eps)) * x_mask + logdet = paddle.sum(-y, [1, 2]) + return y, logdet + else: + x = paddle.exp(x) * x_mask + return x + + +class ElementwiseAffineFlow(nn.Layer): + """Elementwise affine flow module.""" + + def __init__(self, channels: int): + """Initialize ElementwiseAffineFlow module. + Args: + channels (int): Number of channels. + """ + super().__init__() + self.channels = channels + + m = paddle.zeros([channels, 1]) + self.m = paddle.create_parameter( + shape=m.shape, + dtype=str(m.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(m)) + logs = paddle.zeros([channels, 1]) + self.logs = paddle.create_parameter( + shape=logs.shape, + dtype=str(logs.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(logs)) + + def forward(self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + inverse: bool=False, + **kwargs + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
+ """ + if not inverse: + y = self.m + paddle.exp(self.logs) * x + y = y * x_mask + logdet = paddle.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * paddle.exp(-self.logs) * x_mask + return x + + +class Transpose(nn.Layer): + """Transpose module for paddle.nn.Sequential().""" + + def __init__(self, dim1: int, dim2: int): + """Initialize Transpose module.""" + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """Transpose.""" + len_dim = len(x.shape) + orig_perm = list(range(len_dim)) + new_perm = orig_perm[:] + temp = new_perm[self.dim1] + new_perm[self.dim1] = new_perm[self.dim2] + new_perm[self.dim2] = temp + + return paddle.transpose(x, new_perm) + + +class DilatedDepthSeparableConv(nn.Layer): + """Dilated depth-separable conv module.""" + + def __init__( + self, + channels: int, + kernel_size: int, + layers: int, + dropout_rate: float=0.0, + eps: float=1e-5, ): + """Initialize DilatedDepthSeparableConv module. + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + dropout_rate (float): Dropout rate. + eps (float): Epsilon for layer norm. + """ + super().__init__() + + self.convs = nn.LayerList() + for i in range(layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs.append( + nn.Sequential( + nn.Conv1D( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, ), + Transpose(1, 2), + nn.LayerNorm(channels, epsilon=eps), + Transpose(1, 2), + nn.GELU(), + nn.Conv1D( + channels, + channels, + 1, ), + Transpose(1, 2), + nn.LayerNorm(channels, epsilon=eps), + Transpose(1, 2), + nn.GELU(), + nn.Dropout(dropout_rate), )) + + def forward(self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + g: Optional[paddle.Tensor]=None) -> paddle.Tensor: + """Calculate forward propagation. 
+ Args: + x (Tensor): Input tensor (B, in_channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + Returns: + Tensor: Output tensor (B, channels, T). + """ + if g is not None: + x = x + g + for f in self.convs: + y = f(x * x_mask) + x = x + y + return x * x_mask + + +class ConvFlow(nn.Layer): + """Convolutional flow module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + layers: int, + bins: int=10, + tail_bound: float=5.0, ): + """Initialize ConvFlow module. + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + bins (int): Number of bins. + tail_bound (float): Tail bound value. + """ + super().__init__() + self.half_channels = in_channels // 2 + self.hidden_channels = hidden_channels + self.bins = bins + self.tail_bound = tail_bound + + self.input_conv = nn.Conv1D( + self.half_channels, + hidden_channels, + 1, ) + self.dds_conv = DilatedDepthSeparableConv( + hidden_channels, + kernel_size, + layers, + dropout_rate=0.0, ) + self.proj = nn.Conv1D( + hidden_channels, + self.half_channels * (bins * 3 - 1), + 1, ) + + weight = paddle.zeros(paddle.shape(self.proj.weight)) + + self.proj.weight = paddle.create_parameter( + shape=weight.shape, + dtype=str(weight.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(weight)) + + bias = paddle.zeros(paddle.shape(self.proj.bias)) + + self.proj.bias = paddle.create_parameter( + shape=bias.shape, + dtype=str(bias.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(bias)) + + def forward( + self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + g: Optional[paddle.Tensor]=None, + inverse: bool=False, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, channels, T). 
+ x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). + inverse (bool): Whether to inverse the flow. + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + """ + xa, xb = x.split(2, 1) + h = self.input_conv(xa) + h = self.dds_conv(h, x_mask, g=g) + # (B, half_channels * (bins * 3 - 1), T) + h = self.proj(h) * x_mask + + b, c, t = xa.shape + # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) + h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2]) + + denom = math.sqrt(self.hidden_channels) + unnorm_widths = h[..., :self.bins] / denom + unnorm_heights = h[..., self.bins:2 * self.bins] / denom + unnorm_derivatives = h[..., 2 * self.bins:] + xb, logdet_abs = piecewise_rational_quadratic_transform( + xb, + unnorm_widths, + unnorm_heights, + unnorm_derivatives, + inverse=inverse, + tails="linear", + tail_bound=self.tail_bound, ) + x = paddle.concat([xa, xb], 1) * x_mask + logdet = paddle.sum(logdet_abs * x_mask, [1, 2]) + if not inverse: + return x, logdet + else: + return x diff --git a/paddlespeech/t2s/models/vits/generator.py b/paddlespeech/t2s/models/vits/generator.py new file mode 100644 index 000000000..f87de91a2 --- /dev/null +++ b/paddlespeech/t2s/models/vits/generator.py @@ -0,0 +1,550 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Generator module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +import math +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.t2s.models.hifigan import HiFiGANGenerator +from paddlespeech.t2s.models.vits.duration_predictor import StochasticDurationPredictor +from paddlespeech.t2s.models.vits.posterior_encoder import PosteriorEncoder +from paddlespeech.t2s.models.vits.residual_coupling import ResidualAffineCouplingBlock +from paddlespeech.t2s.models.vits.text_encoder import TextEncoder +from paddlespeech.t2s.modules.nets_utils import get_random_segments +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask + + +class VITSGenerator(nn.Layer): + """Generator module in VITS. + This is a module of VITS described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`_. + As text encoder, we use conformer architecture instead of the relative positional + Transformer, which contains additional convolution layers. + .. 
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + vocabs: int, + aux_channels: int=513, + hidden_channels: int=192, + spks: Optional[int]=None, + langs: Optional[int]=None, + spk_embed_dim: Optional[int]=None, + global_channels: int=-1, + segment_size: int=32, + text_encoder_attention_heads: int=2, + text_encoder_ffn_expand: int=4, + text_encoder_blocks: int=6, + text_encoder_positionwise_layer_type: str="conv1d", + text_encoder_positionwise_conv_kernel_size: int=1, + text_encoder_positional_encoding_layer_type: str="rel_pos", + text_encoder_self_attention_layer_type: str="rel_selfattn", + text_encoder_activation_type: str="swish", + text_encoder_normalize_before: bool=True, + text_encoder_dropout_rate: float=0.1, + text_encoder_positional_dropout_rate: float=0.0, + text_encoder_attention_dropout_rate: float=0.0, + text_encoder_conformer_kernel_size: int=7, + use_macaron_style_in_text_encoder: bool=True, + use_conformer_conv_in_text_encoder: bool=True, + decoder_kernel_size: int=7, + decoder_channels: int=512, + decoder_upsample_scales: List[int]=[8, 8, 2, 2], + decoder_upsample_kernel_sizes: List[int]=[16, 16, 4, 4], + decoder_resblock_kernel_sizes: List[int]=[3, 7, 11], + decoder_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5], + [1, 3, 5]], + use_weight_norm_in_decoder: bool=True, + posterior_encoder_kernel_size: int=5, + posterior_encoder_layers: int=16, + posterior_encoder_stacks: int=1, + posterior_encoder_base_dilation: int=1, + posterior_encoder_dropout_rate: float=0.0, + use_weight_norm_in_posterior_encoder: bool=True, + flow_flows: int=4, + flow_kernel_size: int=5, + flow_base_dilation: int=1, + flow_layers: int=4, + flow_dropout_rate: float=0.0, + use_weight_norm_in_flow: bool=True, + use_only_mean_in_flow: bool=True, + stochastic_duration_predictor_kernel_size: int=3, + stochastic_duration_predictor_dropout_rate: float=0.5, + 
stochastic_duration_predictor_flows: int=4, + stochastic_duration_predictor_dds_conv_layers: int=3, ): + """Initialize VITS generator module. + Args: + vocabs (int): Input vocabulary size. + aux_channels (int): Number of acoustic feature channels. + hidden_channels (int): Number of hidden channels. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + global_channels (int): Number of global conditioning channels. + segment_size (int): Segment size for decoder. + text_encoder_attention_heads (int): Number of heads in conformer block + of text encoder. + text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block + of text encoder. + text_encoder_blocks (int): Number of conformer blocks in text encoder. + text_encoder_positionwise_layer_type (str): Position-wise layer type in + conformer block of text encoder. + text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution + kernel size in conformer block of text encoder. Only used when the + above layer type is conv1d or conv1d-linear. + text_encoder_positional_encoding_layer_type (str): Positional encoding layer + type in conformer block of text encoder. + text_encoder_self_attention_layer_type (str): Self-attention layer type in + conformer block of text encoder. + text_encoder_activation_type (str): Activation function type in conformer + block of text encoder. + text_encoder_normalize_before (bool): Whether to apply layer norm before + self-attention in conformer block of text encoder. + text_encoder_dropout_rate (float): Dropout rate in conformer block of + text encoder. 
+ text_encoder_positional_dropout_rate (float): Dropout rate for positional + encoding in conformer block of text encoder. + text_encoder_attention_dropout_rate (float): Dropout rate for attention in + conformer block of text encoder. + text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It + will be used when only use_conformer_conv_in_text_encoder = True. + use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN + in conformer block of text encoder. + use_conformer_conv_in_text_encoder (bool): Whether to use covolution in + conformer block of text encoder. + decoder_kernel_size (int): Decoder kernel size. + decoder_channels (int): Number of decoder initial channels. + decoder_upsample_scales (List[int]): List of upsampling scales in decoder. + decoder_upsample_kernel_sizes (List[int]): List of kernel size for + upsampling layers in decoder. + decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks + in decoder. + decoder_resblock_dilations (List[List[int]]): List of list of dilations for + resblocks in decoder. + use_weight_norm_in_decoder (bool): Whether to apply weight normalization in + decoder. + posterior_encoder_kernel_size (int): Posterior encoder kernel size. + posterior_encoder_layers (int): Number of layers of posterior encoder. + posterior_encoder_stacks (int): Number of stacks of posterior encoder. + posterior_encoder_base_dilation (int): Base dilation of posterior encoder. + posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. + use_weight_norm_in_posterior_encoder (bool): Whether to apply weight + normalization in posterior encoder. + flow_flows (int): Number of flows in flow. + flow_kernel_size (int): Kernel size in flow. + flow_base_dilation (int): Base dilation in flow. + flow_layers (int): Number of layers in flow. + flow_dropout_rate (float): Dropout rate in flow + use_weight_norm_in_flow (bool): Whether to apply weight normalization in + flow. 
+ use_only_mean_in_flow (bool): Whether to use only mean in flow. + stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic + duration predictor. + stochastic_duration_predictor_dropout_rate (float): Dropout rate in + stochastic duration predictor. + stochastic_duration_predictor_flows (int): Number of flows in stochastic + duration predictor. + stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv + layers in stochastic duration predictor. + """ + super().__init__() + self.segment_size = segment_size + self.text_encoder = TextEncoder( + vocabs=vocabs, + attention_dim=hidden_channels, + attention_heads=text_encoder_attention_heads, + linear_units=hidden_channels * text_encoder_ffn_expand, + blocks=text_encoder_blocks, + positionwise_layer_type=text_encoder_positionwise_layer_type, + positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size, + positional_encoding_layer_type=text_encoder_positional_encoding_layer_type, + self_attention_layer_type=text_encoder_self_attention_layer_type, + activation_type=text_encoder_activation_type, + normalize_before=text_encoder_normalize_before, + dropout_rate=text_encoder_dropout_rate, + positional_dropout_rate=text_encoder_positional_dropout_rate, + attention_dropout_rate=text_encoder_attention_dropout_rate, + conformer_kernel_size=text_encoder_conformer_kernel_size, + use_macaron_style=use_macaron_style_in_text_encoder, + use_conformer_conv=use_conformer_conv_in_text_encoder, ) + self.decoder = HiFiGANGenerator( + in_channels=hidden_channels, + out_channels=1, + channels=decoder_channels, + global_channels=global_channels, + kernel_size=decoder_kernel_size, + upsample_scales=decoder_upsample_scales, + upsample_kernel_sizes=decoder_upsample_kernel_sizes, + resblock_kernel_sizes=decoder_resblock_kernel_sizes, + resblock_dilations=decoder_resblock_dilations, + use_weight_norm=use_weight_norm_in_decoder, ) + self.posterior_encoder = PosteriorEncoder( + in_channels=aux_channels, + 
out_channels=hidden_channels, + hidden_channels=hidden_channels, + kernel_size=posterior_encoder_kernel_size, + layers=posterior_encoder_layers, + stacks=posterior_encoder_stacks, + base_dilation=posterior_encoder_base_dilation, + global_channels=global_channels, + dropout_rate=posterior_encoder_dropout_rate, + use_weight_norm=use_weight_norm_in_posterior_encoder, ) + self.flow = ResidualAffineCouplingBlock( + in_channels=hidden_channels, + hidden_channels=hidden_channels, + flows=flow_flows, + kernel_size=flow_kernel_size, + base_dilation=flow_base_dilation, + layers=flow_layers, + global_channels=global_channels, + dropout_rate=flow_dropout_rate, + use_weight_norm=use_weight_norm_in_flow, + use_only_mean=use_only_mean_in_flow, ) + # TODO: Add deterministic version as an option + self.duration_predictor = StochasticDurationPredictor( + channels=hidden_channels, + kernel_size=stochastic_duration_predictor_kernel_size, + dropout_rate=stochastic_duration_predictor_dropout_rate, + flows=stochastic_duration_predictor_flows, + dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, + global_channels=global_channels, ) + + self.upsample_factor = int(np.prod(decoder_upsample_scales)) + self.spks = None + if spks is not None and spks > 1: + assert global_channels > 0 + self.spks = spks + self.global_emb = nn.Embedding(spks, global_channels) + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + assert global_channels > 0 + self.spk_embed_dim = spk_embed_dim + self.spemb_proj = nn.Linear(spk_embed_dim, global_channels) + self.langs = None + if langs is not None and langs > 1: + assert global_channels > 0 + self.langs = langs + self.lang_emb = nn.Embedding(langs, global_channels) + + # delayed import + from paddlespeech.t2s.models.vits.monotonic_align import maximum_path + + self.maximum_path = maximum_path + + def forward( + self, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, 
+ sids: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, + paddle.Tensor, paddle.Tensor, + Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, + paddle.Tensor, paddle.Tensor, ], ]: + """Calculate forward propagation. + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: Duration negative log-likelihood (NLL) tensor (B,). + Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). + Tensor: Segments start index tensor (B,). + Tensor: Text mask tensor (B, 1, T_text). + Tensor: Feature mask tensor (B, 1, T_feats). + tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + - Tensor: Posterior encoder hidden representation (B, H, T_feats). + - Tensor: Flow hidden representation (B, H, T_feats). + - Tensor: Expanded text encoder projected mean (B, H, T_feats). + - Tensor: Expanded text encoder projected scale (B, H, T_feats). + - Tensor: Posterior encoder projected mean (B, H, T_feats). + - Tensor: Posterior encoder projected scale (B, H, T_feats). 
+ """ + # forward text encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + + # calculate global conditioning + g = None + if self.spks is not None: + # speaker one-hot vector embedding: (B, global_channels, 1) + g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1) + if self.spk_embed_dim is not None: + # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # language one-hot vector embedding: (B, global_channels, 1) + g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder( + feats, feats_lengths, g=g) + + # forward flow + # (B, H, T_feats) + z_p = self.flow(z, y_mask, g=g) + + # monotonic alignment search + with paddle.no_grad(): + # negative cross-entropy + # (B, H, T_text) + s_p_sq_r = paddle.exp(-2 * logs_p) + # (B, 1, T_text) + neg_x_ent_1 = paddle.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = paddle.matmul( + -0.5 * (z_p**2).transpose([0, 2, 1]), + s_p_sq_r, ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = paddle.matmul( + z_p.transpose([0, 2, 1]), + (m_p * s_p_sq_r), ) + # (B, 1, T_text) + neg_x_ent_4 = paddle.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask, + -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = (self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), ).unsqueeze(1).detach()) + + # forward duration predictor + # (B, 1, T_text) + w = attn.sum(2) + dur_nll = self.duration_predictor(x, x_mask, 
w=w, g=g) + dur_nll = dur_nll / paddle.sum(x_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = paddle.matmul(attn.squeeze(1), + m_p.transpose([0, 2, 1])).transpose([0, 2, 1]) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = paddle.matmul(attn.squeeze(1), + logs_p.transpose([0, 2, 1])).transpose([0, 2, 1]) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + z, + feats_lengths, + self.segment_size, ) + + # forward decoder with random segments + wav = self.decoder(z_segments, g=g) + + return (wav, dur_nll, attn, z_start_idxs, x_mask, y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), ) + + def inference( + self, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + feats: Optional[paddle.Tensor]=None, + feats_lengths: Optional[paddle.Tensor]=None, + sids: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + dur: Optional[paddle.Tensor]=None, + noise_scale: float=0.667, + noise_scale_dur: float=0.8, + alpha: float=1.0, + max_len: Optional[int]=None, + use_teacher_forcing: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Run inference. + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats,). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, + skip the prediction of durations (i.e., teacher forcing). + noise_scale (float): Noise scale parameter for flow. + noise_scale_dur (float): Noise scale parameter for duration predictor. 
+ alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length of acoustic feature sequence. + use_teacher_forcing (bool): Whether to use teacher forcing. + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Monotonic attention weight tensor (B, T_feats, T_text). + Tensor: Duration tensor (B, T_text). + """ + # encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + g = None + if self.spks is not None: + # (B, global_channels, 1) + g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1) + if self.spk_embed_dim is not None: + # (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # (B, global_channels, 1) + g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + if use_teacher_forcing: + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder( + feats, feats_lengths, g=g) + + # forward flow + # (B, H, T_feats) + z_p = self.flow(z, y_mask, g=g) + + # monotonic alignment search + # (B, H, T_text) + s_p_sq_r = paddle.exp(-2 * logs_p) + # (B, 1, T_text) + neg_x_ent_1 = paddle.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = paddle.matmul( + -0.5 * (z_p**2).transpose([0, 2, 1]), + s_p_sq_r, ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = paddle.matmul( + z_p.transpose([0, 2, 1]), + (m_p * s_p_sq_r), ) + # (B, 1, T_text) + neg_x_ent_4 = paddle.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask, + -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = 
self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), ).unsqueeze(1) + # (B, 1, T_text) + dur = attn.sum(2) + + # forward decoder with random segments + wav = self.decoder(z * y_mask, g=g) + else: + # duration + if dur is None: + logw = self.duration_predictor( + x, + x_mask, + g=g, + inverse=True, + noise_scale=noise_scale_dur, ) + w = paddle.exp(logw) * x_mask * alpha + dur = paddle.ceil(w) + y_lengths = paddle.cast( + paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64') + y_mask = make_non_pad_mask(y_lengths).unsqueeze(1) + attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask, + -1) + attn = self._generate_path(dur, attn_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = paddle.matmul( + attn.squeeze(1), + m_p.transpose([0, 2, 1]), ).transpose([0, 2, 1]) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = paddle.matmul( + attn.squeeze(1), + logs_p.transpose([0, 2, 1]), ).transpose([0, 2, 1]) + + # decoder + z_p = m_p + paddle.randn( + paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, inverse=True) + wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) + + return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) + + def _generate_path(self, dur: paddle.Tensor, + mask: paddle.Tensor) -> paddle.Tensor: + """Generate path a.k.a. monotonic attention. + Args: + dur (Tensor): Duration tensor (B, 1, T_text). + mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). + Returns: + Tensor: Path tensor (B, 1, T_feats, T_text). 
+ """ + b, _, t_y, t_x = paddle.shape(mask) + cum_dur = paddle.cumsum(dur, -1) + cum_dur_flat = paddle.reshape(cum_dur, [b * t_x]) + + path = paddle.arange(t_y, dtype=dur.dtype) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + path = paddle.reshape(path, [b, t_x, t_y]) + ''' + path will be like (t_x = 3, t_y = 5): + [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + ''' + + path = paddle.cast(path, dtype='float32') + path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask diff --git a/paddlespeech/t2s/models/vits/monotonic_align/__init__.py b/paddlespeech/t2s/models/vits/monotonic_align/__init__.py new file mode 100644 index 000000000..3aa47ed72 --- /dev/null +++ b/paddlespeech/t2s/models/vits/monotonic_align/__init__.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Maximum path calculation module. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +import warnings + +import numpy as np +import paddle +from numba import njit +from numba import prange + +try: + from .core import maximum_path_c + + is_cython_avalable = True +except ImportError: + is_cython_avalable = False + warnings.warn( + "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. 
" + "If you want to use the cython version, please build it as follows: " + "`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`" + ) + + +def maximum_path(neg_x_ent: paddle.Tensor, + attn_mask: paddle.Tensor) -> paddle.Tensor: + """Calculate maximum path. + + Args: + neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). + attn_mask (Tensor): Attention mask (B, T_feats, T_text). + + Returns: + Tensor: Maximum path tensor (B, T_feats, T_text). + + """ + dtype = neg_x_ent.dtype + neg_x_ent = neg_x_ent.numpy().astype(np.float32) + path = np.zeros(neg_x_ent.shape, dtype=np.int32) + t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) + t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) + if is_cython_avalable: + maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) + else: + maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) + + return paddle.cast(paddle.to_tensor(path), dtype=dtype) + + +@njit +def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): + """Calculate a single maximum path with numba.""" + index = t_x - 1 + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or + value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@njit(parallel=True) +def maximum_path_numba(paths, values, t_ys, t_xs): + """Calculate batch maximum path with numba.""" + for i in prange(paths.shape[0]): + maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/paddlespeech/t2s/models/vits/monotonic_align/core.pyx b/paddlespeech/t2s/models/vits/monotonic_align/core.pyx new file mode 100644 index 000000000..5a573dc74 --- /dev/null +++ 
b/paddlespeech/t2s/models/vits/monotonic_align/core.pyx @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Maximum path calculation module with cython optimization. + +This code is copied from https://github.com/jaywalnut310/vits and modifed code format. + +""" + +cimport cython + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git 
a/paddlespeech/t2s/models/vits/monotonic_align/setup.py b/paddlespeech/t2s/models/vits/monotonic_align/setup.py new file mode 100644 index 000000000..8df03ab12 --- /dev/null +++ b/paddlespeech/t2s/models/vits/monotonic_align/setup.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup cython code.""" +from Cython.Build import cythonize +from setuptools import Extension +from setuptools import setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [Extension( + name="core", + sources=["core.pyx"], )] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, ) diff --git a/paddlespeech/t2s/models/vits/posterior_encoder.py b/paddlespeech/t2s/models/vits/posterior_encoder.py new file mode 100644 index 000000000..853237557 --- /dev/null +++ b/paddlespeech/t2s/models/vits/posterior_encoder.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +from typing import Optional +from typing import Tuple + +import paddle +from paddle import nn + +from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask + + +class PosteriorEncoder(nn.Layer): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int=513, + out_channels: int=192, + hidden_channels: int=192, + kernel_size: int=5, + layers: int=16, + stacks: int=1, + base_dilation: int=1, + global_channels: int=-1, + dropout_rate: float=0.0, + bias: bool=True, + use_weight_norm: bool=True, ): + """Initilialize PosteriorEncoder module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. 
+ global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, ) + self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1) + + def forward( + self, + x: paddle.Tensor, + x_lengths: paddle.Tensor, + g: Optional[paddle.Tensor]=None + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). 
+ + """ + x_mask = make_non_pad_mask(x_lengths).unsqueeze(1) + x = self.input_conv(x) * x_mask + x = self.encoder(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = paddle.split(stats, 2, axis=1) + z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask + + return z, m, logs, x_mask diff --git a/paddlespeech/t2s/models/vits/residual_coupling.py b/paddlespeech/t2s/models/vits/residual_coupling.py new file mode 100644 index 000000000..c18beedd0 --- /dev/null +++ b/paddlespeech/t2s/models/vits/residual_coupling.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Residual affine coupling modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +from typing import Optional +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn + +from paddlespeech.t2s.models.vits.flow import FlipFlow +from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet + + +class ResidualAffineCouplingBlock(nn.Layer): + """Residual affine coupling block module. + + This is a module of residual affine coupling block, which used as "Flow" in + `Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`_. + + .. 
    def __init__(
            self,
            in_channels: int=192,
            hidden_channels: int=192,
            flows: int=4,
            kernel_size: int=5,
            base_dilation: int=1,
            layers: int=4,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            use_weight_norm: bool=True,
            bias: bool=True,
            use_only_mean: bool=True, ):
        """Initialize ResidualAffineCouplingBlock module.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels.
            flows (int): Number of flows.
            kernel_size (int): Kernel size for WaveNet.
            base_dilation (int): Base dilation factor for WaveNet.
            layers (int): Number of layers of WaveNet.
            global_channels (int): Number of global channels.
            dropout_rate (float): Dropout rate.
            use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
            use_only_mean (bool): Whether to estimate only mean.

        Note:
            The number of WaveNet stacks per coupling layer is fixed to 1.

        """
        super().__init__()

        self.flows = nn.LayerList()
        for i in range(flows):
            self.flows.append(
                ResidualAffineCouplingLayer(
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    base_dilation=base_dilation,
                    layers=layers,
                    stacks=1,
                    global_channels=global_channels,
                    dropout_rate=dropout_rate,
                    use_weight_norm=use_weight_norm,
                    bias=bias,
                    use_only_mean=use_only_mean, ))
            # A coupling layer only transforms one half of the channels, so
            # flip the halves between layers to let both halves be updated.
            self.flows.append(FlipFlow())

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: paddle.Tensor,
            g: Optional[paddle.Tensor]=None,
            inverse: bool=False, ) -> paddle.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Length tensor (B, 1, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
            inverse (bool): Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).

        """
        if not inverse:
            for flow in self.flows:
                # The per-layer log-determinant is discarded here; only the
                # transformed tensor is propagated.
                x, _ = flow(x, x_mask, g=g, inverse=inverse)
        else:
            # Invert by applying each flow's inverse in reverse order.
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, inverse=inverse)
        return x
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = nn.Conv1D( + self.half_channels, + hidden_channels, + 1, ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, ) + if use_only_mean: + self.proj = nn.Conv1D( + hidden_channels, + self.half_channels, + 1, ) + else: + self.proj = nn.Conv1D( + hidden_channels, + self.half_channels * 2, + 1, ) + + weight = paddle.zeros(paddle.shape(self.proj.weight)) + + self.proj.weight = paddle.create_parameter( + shape=weight.shape, + dtype=str(weight.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(weight)) + + bias = paddle.zeros(paddle.shape(self.proj.bias)) + + self.proj.bias = paddle.create_parameter( + shape=bias.shape, + dtype=str(bias.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign(bias)) + + def forward( + self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + g: Optional[paddle.Tensor]=None, + inverse: bool=False, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
+ + """ + xa, xb = paddle.split(x, 2, axis=1) + h = self.input_conv(xa) * x_mask + h = self.encoder(h, x_mask, g=g) + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = paddle.split(stats, 2, axis=1) + else: + m = stats + logs = paddle.zeros(paddle.shape(m)) + + if not inverse: + xb = m + xb * paddle.exp(logs) * x_mask + x = paddle.concat([xa, xb], 1) + logdet = paddle.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * paddle.exp(-logs) * x_mask + x = paddle.concat([xa, xb], 1) + return x diff --git a/paddlespeech/t2s/models/vits/text_encoder.py b/paddlespeech/t2s/models/vits/text_encoder.py new file mode 100644 index 000000000..3afc7831a --- /dev/null +++ b/paddlespeech/t2s/models/vits/text_encoder.py @@ -0,0 +1,145 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" +import math +from typing import Tuple + +import paddle +from paddle import nn + +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder + + +class TextEncoder(nn.Layer): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`_. 
    def __init__(
            self,
            vocabs: int,
            attention_dim: int=192,
            attention_heads: int=2,
            linear_units: int=768,
            blocks: int=6,
            positionwise_layer_type: str="conv1d",
            positionwise_conv_kernel_size: int=3,
            positional_encoding_layer_type: str="rel_pos",
            self_attention_layer_type: str="rel_selfattn",
            activation_type: str="swish",
            normalize_before: bool=True,
            use_macaron_style: bool=False,
            use_conformer_conv: bool=False,
            conformer_kernel_size: int=7,
            dropout_rate: float=0.1,
            positional_dropout_rate: float=0.0,
            attention_dropout_rate: float=0.0, ):
        """Initialize TextEncoder module.

        Args:
            vocabs (int): Vocabulary size.
            attention_dim (int): Attention dimension.
            attention_heads (int): Number of attention heads.
            linear_units (int): Number of linear units of positionwise layers.
            blocks (int): Number of encoder blocks.
            positionwise_layer_type (str): Positionwise layer type.
            positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
            positional_encoding_layer_type (str): Positional encoding layer type.
            self_attention_layer_type (str): Self-attention layer type.
            activation_type (str): Activation function type.
            normalize_before (bool): Whether to apply LayerNorm before attention.
            use_macaron_style (bool): Whether to use macaron style components.
            use_conformer_conv (bool): Whether to use conformer conv layers.
            conformer_kernel_size (int): Conformer's conv kernel size.
            dropout_rate (float): Dropout rate.
            positional_dropout_rate (float): Dropout rate for positional encoding.
            attention_dropout_rate (float): Dropout rate for attention.

        """
        super().__init__()
        # store for forward
        self.attention_dim = attention_dim

        # define modules
        self.emb = nn.Embedding(vocabs, attention_dim)

        # Re-initialize the embedding with N(0, attention_dim^-0.5), matching
        # the scaled-embedding convention used with sqrt(d) scaling in forward.
        dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
        w = dist.sample(self.emb.weight.shape)
        self.emb.weight.set_value(w)

        # input_layer=None / idim=-1: embeddings are computed here, so the
        # encoder receives already-embedded sequences.
        self.encoder = Encoder(
            idim=-1,
            input_layer=None,
            attention_dim=attention_dim,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=blocks,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            normalize_before=normalize_before,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style,
            pos_enc_layer_type=positional_encoding_layer_type,
            selfattention_layer_type=self_attention_layer_type,
            activation_type=activation_type,
            use_cnn_module=use_conformer_conv,
            cnn_module_kernel=conformer_kernel_size, )
        # Projects hidden states to concatenated [mean, log-scale] stats.
        self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)

    def forward(
            self,
            x: paddle.Tensor,
            x_lengths: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input index tensor (B, T_text).
            x_lengths (Tensor): Length tensor (B,).

        Returns:
            Tensor: Encoded hidden representation (B, attention_dim, T_text).
            Tensor: Projected mean tensor (B, attention_dim, T_text).
            Tensor: Projected scale tensor (B, attention_dim, T_text).
            Tensor: Mask tensor for input tensor (B, 1, T_text).

        """
        x = self.emb(x) * math.sqrt(self.attention_dim)
        x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
        # encoder assume the channel last (B, T_text, attention_dim)
        # but mask shape should be (B, 1, T_text)
        x, _ = self.encoder(x, x_mask)

        # convert the channel first (B, attention_dim, T_text)
        x = paddle.transpose(x, [0, 2, 1])
        stats = self.proj(x) * x_mask
        m, logs = paddle.split(stats, 2, axis=1)

        return x, m, logs, x_mask
+ +""" +import numpy as np +import paddle +from paddle.nn import functional as F + +from paddlespeech.t2s.modules.nets_utils import paddle_gather + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, ): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs) + return outputs, logabsdet + + +def mask_preprocess(x, mask): + B, C, T, bins = paddle.shape(x) + new_x = paddle.zeros([mask.sum(), bins]) + for i in range(bins): + new_x[:, i] = x[:, :, :, i][mask] + return new_x + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, ): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = paddle.zeros(paddle.shape(inputs)) + logabsdet = paddle.zeros(paddle.shape(inputs)) + if tails == "linear": + unnormalized_derivatives = F.pad( + unnormalized_derivatives, + pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1]) + constant = 
np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + unnormalized_widths = mask_preprocess(unnormalized_widths, + inside_interval_mask) + unnormalized_heights = mask_preprocess(unnormalized_heights, + inside_interval_mask) + unnormalized_derivatives = mask_preprocess(unnormalized_derivatives, + inside_interval_mask) + + (outputs[inside_interval_mask], + logabsdet[inside_interval_mask], ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, ): + if paddle.min(inputs) < left or paddle.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, axis=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = paddle.cumsum(widths, axis=-1) + cumwidths = F.pad( + 
def rational_quadratic_spline(
        inputs,
        unnormalized_widths,
        unnormalized_heights,
        unnormalized_derivatives,
        inverse=False,
        left=0.0,
        right=1.0,
        bottom=0.0,
        top=1.0,
        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
    """Monotonic rational-quadratic spline on [left, right] -> [bottom, top].

    Implements the element-wise transform of Durkan et al., "Neural Spline
    Flows"; returns the transformed inputs and the element-wise
    log-absolute-determinant of the Jacobian.
    """
    if paddle.min(inputs) < left or paddle.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalized bin widths with a floor of min_bin_width each.
    widths = F.softmax(unnormalized_widths, axis=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = paddle.cumsum(widths, axis=-1)
    # Prepend a 0 so cumwidths has num_bins + 1 knot positions.
    cumwidths = F.pad(
        cumwidths,
        pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
        mode="constant",
        value=0.0)
    cumwidths = (right - left) * cumwidths + left
    # Pin the endpoints exactly to avoid floating-point drift.
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Knot derivatives, strictly positive via softplus + floor.
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Same construction for the output-axis bins.
    heights = F.softmax(unnormalized_heights, axis=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = paddle.cumsum(heights, axis=-1)
    cumheights = F.pad(
        cumheights,
        pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
        mode="constant",
        value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate each input's bin (inverse searches the output axis).
    if inverse:
        bin_idx = _searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = _searchsorted(cumwidths, inputs)[..., None]
    input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
    input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]

    input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
    # delta: per-bin slope heights / widths.
    delta = heights / widths
    input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]

    input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
    input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
                                               bin_idx)[..., 0]

    input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]

    if inverse:
        # Invert the rational-quadratic by solving a * theta^2 + b * theta + c = 0
        # for theta in [0, 1] (quadratic coefficients from the paper).
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta)
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Numerically stable quadratic root (avoids cancellation).
        root = (2 * c) / (-b - paddle.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta
             ) * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
            theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
        logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
            denominator)

        # Inverse direction: negate the forward log|det|.
        return outputs, -logabsdet
    else:
        # theta: relative position within the bin, in [0, 1].
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) +
                                     input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta
             ) * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
            theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
        logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
            denominator)

        return outputs, logabsdet


def _searchsorted(bin_locations, inputs, eps=1e-6):
    """Return the bin index of each input among sorted bin edges.

    NOTE(review): mutates ``bin_locations`` in place (nudges the last edge
    by eps so inputs equal to the right boundary land in the last bin) —
    matches the upstream nflows implementation, but callers should not rely
    on the edge tensor afterwards.
    """
    bin_locations[..., -1] += eps
    return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
class VITS(nn.Layer):
    """VITS module (generator + discriminator).

    This is a module of VITS described in `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2006.04558
    """

    def __init__(
            self,
            # generator related
            idim: int,
            odim: int,
            sampling_rate: int=22050,
            generator_type: str="vits_generator",
            generator_params: Dict[str, Any]={
                "hidden_channels": 192,
                "spks": None,
                "langs": None,
                "spk_embed_dim": None,
                "global_channels": -1,
                "segment_size": 32,
                "text_encoder_attention_heads": 2,
                "text_encoder_ffn_expand": 4,
                "text_encoder_blocks": 6,
                "text_encoder_positionwise_layer_type": "conv1d",
                "text_encoder_positionwise_conv_kernel_size": 1,
                "text_encoder_positional_encoding_layer_type": "rel_pos",
                "text_encoder_self_attention_layer_type": "rel_selfattn",
                "text_encoder_activation_type": "swish",
                "text_encoder_normalize_before": True,
                "text_encoder_dropout_rate": 0.1,
                "text_encoder_positional_dropout_rate": 0.0,
                "text_encoder_attention_dropout_rate": 0.0,
                "text_encoder_conformer_kernel_size": 7,
                "use_macaron_style_in_text_encoder": True,
                "use_conformer_conv_in_text_encoder": True,
                "decoder_kernel_size": 7,
                "decoder_channels": 512,
                "decoder_upsample_scales": [8, 8, 2, 2],
                "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
                "decoder_resblock_kernel_sizes": [3, 7, 11],
                "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "use_weight_norm_in_decoder": True,
                "posterior_encoder_kernel_size": 5,
                "posterior_encoder_layers": 16,
                "posterior_encoder_stacks": 1,
                "posterior_encoder_base_dilation": 1,
                "posterior_encoder_dropout_rate": 0.0,
                "use_weight_norm_in_posterior_encoder": True,
                "flow_flows": 4,
                "flow_kernel_size": 5,
                "flow_base_dilation": 1,
                "flow_layers": 4,
                "flow_dropout_rate": 0.0,
                "use_weight_norm_in_flow": True,
                "use_only_mean_in_flow": True,
                "stochastic_duration_predictor_kernel_size": 3,
                "stochastic_duration_predictor_dropout_rate": 0.5,
                "stochastic_duration_predictor_flows": 4,
                "stochastic_duration_predictor_dds_conv_layers": 3,
            },
            # discriminator related
            discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
            discriminator_params: Dict[str, Any]={
                "scales": 1,
                "scale_downsample_pooling": "AvgPool1D",
                "scale_downsample_pooling_params": {
                    "kernel_size": 4,
                    "stride": 2,
                    "padding": 2,
                },
                "scale_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [15, 41, 5, 3],
                    "channels": 128,
                    "max_downsample_channels": 1024,
                    "max_groups": 16,
                    "bias": True,
                    "downsample_scales": [2, 2, 4, 4, 1],
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
                "follow_official_norm": False,
                "periods": [2, 3, 5, 7, 11],
                "period_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [5, 3],
                    "channels": 32,
                    "downsample_scales": [3, 3, 3, 3, 1],
                    "max_downsample_channels": 1024,
                    "bias": True,
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
            },
            cache_generator_outputs: bool=True,
            init_type: str="xavier_uniform", ):
        """Initialize VITS module.

        Args:
            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int): Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
            generator_type (str): Generator type.
            generator_params (Dict[str, Any]): Parameter dict for generator.
            discriminator_type (str): Discriminator type.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
            cache_generator_outputs (bool): Whether to cache generator outputs.
            init_type (str): Parameter initialization scheme name passed to
                ``initialize``.
        """
        assert check_argument_types()
        super().__init__()

        # initialize parameters
        initialize(self, init_type)

        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if generator_type == "vits_generator":
            # NOTE: Update parameters for the compatibility.
            # The idim and odim is automatically decided from input data,
            # where idim represents #vocabularies and odim represents
            # the input acoustic feature dimension.
            # Copy first: ``generator_params`` may be the shared mutable
            # default dict (or the caller's own dict) and must not be
            # mutated in place.
            generator_params = dict(generator_params)
            generator_params.update(vocabs=idim, aux_channels=odim)
        self.generator = generator_class(
            **generator_params, )
        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
        self.discriminator = discriminator_class(
            **discriminator_params, )

        # Reset the global initializer installed by ``initialize`` so later
        # layers are not affected.
        nn.initializer.set_global_initializer(None)

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.fs = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim

        self.reuse_cache_gen = True
        self.reuse_cache_dis = True

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            forward_generator: bool=True, ) -> Dict[str, Any]:
        """Perform generator or discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            forward_generator (bool): Whether to forward generator (True) or
                discriminator (False).

        Returns:
            Generator outputs (see ``_forward_generator`` /
            ``_forward_discrminator``).
        """
        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            return self._forward_discrminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )

    def _forward_generator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Outputs of ``self.generator`` (possibly served from the cache
            shared with the discriminator step).
        """
        # setup
        feats = feats.transpose([0, 2, 1])

        # calculate generator outputs
        self.reuse_cache_gen = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_gen = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            outs = self._cache

        # store cache
        if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
            self._cache = outs

        return outs

    # NOTE(review): method name keeps the upstream typo ("discrminator");
    # kept as-is for compatibility with existing callers/checkpoint tooling.
    def _forward_discrminator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
        """Perform the generator forward for the discriminator step.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Outputs of ``self.generator`` (possibly served from the cache
            shared with the generator step).
        """
        # setup
        feats = feats.transpose([0, 2, 1])

        # calculate generator outputs
        self.reuse_cache_dis = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_dis = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            outs = self._cache

        # store cache
        if self.cache_generator_outputs and not self.reuse_cache_dis:
            self._cache = outs

        return outs

    def inference(
            self,
            text: paddle.Tensor,
            feats: Optional[paddle.Tensor]=None,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            durations: Optional[paddle.Tensor]=None,
            noise_scale: float=0.667,
            noise_scale_dur: float=0.8,
            alpha: float=1.0,
            max_len: Optional[int]=None,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Tensor): Feature tensor (T_feats, aux_channels).
            sids (Tensor): Speaker index tensor (1,).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            lids (Tensor): Language index tensor (1,).
            durations (Tensor): Ground-truth duration tensor (T_text,).
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (T_wav,).
                * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
                * duration (Tensor): Predicted duration tensor (T_text,).
        """
        # setup: add the batch dimension expected by the generator
        text = text[None]
        text_lengths = paddle.to_tensor(paddle.shape(text)[1])

        if durations is not None:
            durations = paddle.reshape(durations, [1, 1, -1])

        # inference
        if use_teacher_forcing:
            assert feats is not None
            feats = feats[None].transpose([0, 2, 1])
            feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                max_len=max_len,
                use_teacher_forcing=use_teacher_forcing, )
        else:
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                dur=durations,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len, )
        return dict(
            wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
import logging
from typing import Dict

import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler

from paddlespeech.t2s.modules.nets_utils import get_segments
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class VITSUpdater(StandardUpdater):
    """GAN-style training updater for VITS.

    Each `update_core` call runs one generator turn and one discriminator
    turn (order controlled by `generator_first`), stepping the matching
    optimizer and LR scheduler for each turn.
    """

    def __init__(self,
                 model: Layer,
                 optimizers: Dict[str, Optimizer],
                 criterions: Dict[str, Layer],
                 schedulers: Dict[str, LRScheduler],
                 dataloader: DataLoader,
                 generator_train_start_steps: int=0,
                 discriminator_train_start_steps: int=100000,
                 lambda_adv: float=1.0,
                 lambda_mel: float=45.0,
                 lambda_feat_match: float=2.0,
                 lambda_dur: float=1.0,
                 lambda_kl: float=1.0,
                 generator_first: bool=False,
                 output_dir=None):
        # The base class is designed to hold multiple models; since a single
        # model is passed in and the parent __init__() is not called, that
        # setup is re-done here.
        models = {"main": model}
        self.models: Dict[str, Layer] = models

        # unwrap DataParallel so generator/discriminator submodules are reachable
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model

        self.optimizers = optimizers
        self.optimizer_g: Optimizer = optimizers['generator']
        self.optimizer_d: Optimizer = optimizers['discriminator']

        self.criterions = criterions
        self.criterion_mel = criterions['mel']
        self.criterion_feat_match = criterions['feat_match']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]
        self.criterion_kl = criterions["kl"]

        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        self.dataloader = dataloader

        # NOTE(review): these warm-up step counts are stored but never
        # consulted in update_core — confirm whether gating is still intended.
        self.generator_train_start_steps = generator_train_start_steps
        self.discriminator_train_start_steps = discriminator_train_start_steps

        # loss weights
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
        self.lambda_feat_match = lambda_feat_match
        self.lambda_dur = lambda_dur
        self.lambda_kl = lambda_kl

        if generator_first:
            self.turns = ["generator", "discriminator"]
        else:
            self.turns = ["discriminator", "generator"]

        self.state = UpdaterState(iteration=0, epoch=0)
        self.train_iterator = iter(self.dataloader)

        # per-rank log file
        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        """Update parameters on one batch: one pass per turn in self.turns."""
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}

        for turn in self.turns:
            # raw waveform: (B, T_wav) -> (B, 1, T_wav)
            speech = batch["speech"]
            speech = speech.unsqueeze(1)
            outs = self.model(
                text=batch["text"],
                text_lengths=batch["text_lengths"],
                feats=batch["feats"],
                feats_lengths=batch["feats_lengths"],
                forward_generator=turn == "generator")
            # Generator
            if turn == "generator":
                # parse outputs
                speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
                _, z_p, m_p, logs_p, _, logs_q = outs_
                # cut the matching ground-truth waveform segment
                # (start indices are in frames; scale to samples)
                speech_ = get_segments(
                    x=speech,
                    start_idxs=start_idxs *
                    self.model.generator.upsample_factor,
                    segment_size=self.model.generator.segment_size *
                    self.model.generator.upsample_factor, )

                # calculate discriminator outputs
                p_hat = self.model.discriminator(speech_hat_)
                with paddle.no_grad():
                    # do not store discriminator gradient in generator turn
                    p = self.model.discriminator(speech_)

                # calculate losses
                mel_loss = self.criterion_mel(speech_hat_, speech_)
                kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
                dur_loss = paddle.sum(dur_nll)
                adv_loss = self.criterion_gen_adv(p_hat)
                feat_match_loss = self.criterion_feat_match(p_hat, p)

                # apply loss weights
                mel_loss = mel_loss * self.lambda_mel
                kl_loss = kl_loss * self.lambda_kl
                dur_loss = dur_loss * self.lambda_dur
                adv_loss = adv_loss * self.lambda_adv
                feat_match_loss = feat_match_loss * self.lambda_feat_match
                gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

                report("train/generator_loss", float(gen_loss))
                report("train/generator_mel_loss", float(mel_loss))
                report("train/generator_kl_loss", float(kl_loss))
                report("train/generator_dur_loss", float(dur_loss))
                report("train/generator_adv_loss", float(adv_loss))
                report("train/generator_feat_match_loss",
                       float(feat_match_loss))

                losses_dict["generator_loss"] = float(gen_loss)
                losses_dict["generator_mel_loss"] = float(mel_loss)
                losses_dict["generator_kl_loss"] = float(kl_loss)
                losses_dict["generator_dur_loss"] = float(dur_loss)
                losses_dict["generator_adv_loss"] = float(adv_loss)
                losses_dict["generator_feat_match_loss"] = float(
                    feat_match_loss)

                self.optimizer_g.clear_grad()
                gen_loss.backward()

                self.optimizer_g.step()
                self.scheduler_g.step()

                # reset cache
                if self.model.reuse_cache_gen or not self.model.training:
                    self.model._cache = None

            # Discriminator
            elif turn == "discriminator":
                # parse outputs
                speech_hat_, _, _, start_idxs, *_ = outs
                speech_ = get_segments(
                    x=speech,
                    start_idxs=start_idxs *
                    self.model.generator.upsample_factor,
                    segment_size=self.model.generator.segment_size *
                    self.model.generator.upsample_factor, )

                # calculate discriminator outputs
                # detach so no gradient flows into the generator on this turn
                p_hat = self.model.discriminator(speech_hat_.detach())
                p = self.model.discriminator(speech_)

                # calculate losses
                real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
                dis_loss = real_loss + fake_loss

                report("train/real_loss", float(real_loss))
                report("train/fake_loss", float(fake_loss))
                report("train/discriminator_loss", float(dis_loss))
                losses_dict["real_loss"] = float(real_loss)
                losses_dict["fake_loss"] = float(fake_loss)
                losses_dict["discriminator_loss"] = float(dis_loss)

                self.optimizer_d.clear_grad()
                dis_loss.backward()

                self.optimizer_d.step()
                self.scheduler_d.step()

                # reset cache
                if self.model.reuse_cache_dis or not self.model.training:
                    self.model._cache = None

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())


class VITSEvaluator(StandardEvaluator):
    """Evaluation counterpart of `VITSUpdater`.

    Computes and reports the same generator/discriminator losses on
    evaluation batches, without any parameter updates.
    """

    def __init__(self,
                 model,
                 criterions: Dict[str, Layer],
                 dataloader: DataLoader,
                 lambda_adv: float=1.0,
                 lambda_mel: float=45.0,
                 lambda_feat_match: float=2.0,
                 lambda_dur: float=1.0,
                 lambda_kl: float=1.0,
                 generator_first: bool=False,
                 output_dir=None):
        # The base class is designed to hold multiple models; since a single
        # model is passed in and the parent __init__() is not used, that
        # setup is re-done here.
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        # unwrap DataParallel so generator/discriminator submodules are reachable
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model

        self.criterions = criterions
        self.criterion_mel = criterions['mel']
        self.criterion_feat_match = criterions['feat_match']
        self.criterion_gen_adv = criterions["gen_adv"]
        self.criterion_dis_adv = criterions["dis_adv"]
        self.criterion_kl = criterions["kl"]

        self.dataloader = dataloader

        # loss weights
        self.lambda_adv = lambda_adv
        self.lambda_mel = lambda_mel
        self.lambda_feat_match = lambda_feat_match
        self.lambda_dur = lambda_dur
        self.lambda_kl = lambda_kl

        if generator_first:
            self.turns = ["generator", "discriminator"]
        else:
            self.turns = ["discriminator", "generator"]

        # per-rank log file
        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        """Compute and report all losses for one evaluation batch."""
        self.msg = "Evaluate: "
        losses_dict = {}

        for turn in self.turns:
            # raw waveform: (B, T_wav) -> (B, 1, T_wav)
            speech = batch["speech"]
            speech = speech.unsqueeze(1)
            outs = self.model(
                text=batch["text"],
                text_lengths=batch["text_lengths"],
                feats=batch["feats"],
                feats_lengths=batch["feats_lengths"],
                forward_generator=turn == "generator")
            # Generator
            if turn == "generator":
                # parse outputs
                speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
                _, z_p, m_p, logs_p, _, logs_q = outs_
                # cut the matching ground-truth waveform segment
                speech_ = get_segments(
                    x=speech,
                    start_idxs=start_idxs *
                    self.model.generator.upsample_factor,
                    segment_size=self.model.generator.segment_size *
                    self.model.generator.upsample_factor, )

                # calculate discriminator outputs
                p_hat = self.model.discriminator(speech_hat_)
                with paddle.no_grad():
                    # do not store discriminator gradient in generator turn
                    p = self.model.discriminator(speech_)

                # calculate losses
                mel_loss = self.criterion_mel(speech_hat_, speech_)
                kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
                dur_loss = paddle.sum(dur_nll)
                adv_loss = self.criterion_gen_adv(p_hat)
                feat_match_loss = self.criterion_feat_match(p_hat, p)

                # apply loss weights
                mel_loss = mel_loss * self.lambda_mel
                kl_loss = kl_loss * self.lambda_kl
                dur_loss = dur_loss * self.lambda_dur
                adv_loss = adv_loss * self.lambda_adv
                feat_match_loss = feat_match_loss * self.lambda_feat_match
                gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

                report("eval/generator_loss", float(gen_loss))
                report("eval/generator_mel_loss", float(mel_loss))
                report("eval/generator_kl_loss", float(kl_loss))
                report("eval/generator_dur_loss", float(dur_loss))
                report("eval/generator_adv_loss", float(adv_loss))
                report("eval/generator_feat_match_loss", float(feat_match_loss))

                losses_dict["generator_loss"] = float(gen_loss)
                losses_dict["generator_mel_loss"] = float(mel_loss)
                losses_dict["generator_kl_loss"] = float(kl_loss)
                losses_dict["generator_dur_loss"] = float(dur_loss)
                losses_dict["generator_adv_loss"] = float(adv_loss)
                losses_dict["generator_feat_match_loss"] = float(
                    feat_match_loss)

                # reset cache
                if self.model.reuse_cache_gen or not self.model.training:
                    self.model._cache = None

            # Discriminator
            elif turn == "discriminator":
                # parse outputs
                speech_hat_, _, _, start_idxs, *_ = outs
                speech_ = get_segments(
                    x=speech,
                    start_idxs=start_idxs *
                    self.model.generator.upsample_factor,
                    segment_size=self.model.generator.segment_size *
                    self.model.generator.upsample_factor, )

                # calculate discriminator outputs
                p_hat = self.model.discriminator(speech_hat_.detach())
                p = self.model.discriminator(speech_)

                # calculate losses
                real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
                dis_loss = real_loss + fake_loss

                report("eval/real_loss", float(real_loss))
                report("eval/fake_loss", float(fake_loss))
                report("eval/discriminator_loss", float(dis_loss))
                losses_dict["real_loss"] = float(real_loss)
                losses_dict["fake_loss"] = float(fake_loss)
                losses_dict["discriminator_loss"] = float(dis_loss)

                # reset cache
                if self.model.reuse_cache_dis or not self.model.training:
                    self.model._cache = None

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
class ResidualBlock(nn.Layer):
    """Residual block module in WaveNet."""

    def __init__(
            self,
            kernel_size: int=3,
            residual_channels: int=64,
            gate_channels: int=128,
            skip_channels: int=64,
            aux_channels: int=80,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            dilation: int=1,
            bias: bool=True,
            scale_residual: bool=False, ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels for the gated activation
                (split into two halves, so it must be even).
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Number of local conditioning channels
                (disabled when <= 0).
            global_channels (int): Number of global conditioning channels
                (disabled when <= 0).
            dropout_rate (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            scale_residual (bool): Whether to scale the residual outputs.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.scale_residual = scale_residual

        # check
        assert (
            kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert gate_channels % 2 == 0

        # dilation conv ("same" padding so T is preserved)
        padding = (kernel_size - 1) // 2 * dilation
        self.conv = nn.Conv1D(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias_attr=bias, )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = nn.Conv1D(
                aux_channels, gate_channels, kernel_size=1, bias_attr=False)
        else:
            self.conv1x1_aux = None

        # global conditioning
        if global_channels > 0:
            self.conv1x1_glo = nn.Conv1D(
                global_channels, gate_channels, kernel_size=1, bias_attr=False)
        else:
            self.conv1x1_glo = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2

        # NOTE: concat two convs into a single conv for the efficiency
        # (integrate res 1x1 + skip 1x1 convs)
        self.conv1x1_out = nn.Conv1D(
            gate_out_channels,
            residual_channels + skip_channels,
            kernel_size=1,
            bias_attr=bias)

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: Optional[paddle.Tensor]=None,
            c: Optional[paddle.Tensor]=None,
            g: Optional[paddle.Tensor]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv(x)

        # split into two part for gated activation
        splitdim = 1
        xa, xb = paddle.split(x, 2, axis=splitdim)

        # local conditioning
        if c is not None:
            c = self.conv1x1_aux(c)
            ca, cb = paddle.split(c, 2, axis=splitdim)
            xa, xb = xa + ca, xb + cb

        # global conditioning
        if g is not None:
            g = self.conv1x1_glo(g)
            ga, gb = paddle.split(g, 2, axis=splitdim)
            xa, xb = xa + ga, xb + gb

        # gated activation: tanh gate * sigmoid gate
        x = paddle.tanh(xa) * F.sigmoid(xb)

        # residual + skip 1x1 conv
        x = self.conv1x1_out(x)
        if x_mask is not None:
            x = x * x_mask

        # split integrated conv results
        x, s = paddle.split(
            x, [self.residual_channels, self.skip_channels], axis=1)

        # for residual connection
        x = x + residual
        if self.scale_residual:
            x = x * math.sqrt(0.5)

        return x, s
class WaveNet(nn.Layer):
    """WaveNet with global conditioning."""

    def __init__(
            self,
            in_channels: int=1,
            out_channels: int=1,
            kernel_size: int=3,
            layers: int=30,
            stacks: int=3,
            base_dilation: int=2,
            residual_channels: int=64,
            aux_channels: int=-1,
            gate_channels: int=128,
            skip_channels: int=64,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            bias: bool=True,
            use_weight_norm: bool=True,
            use_first_conv: bool=False,
            use_last_conv: bool=False,
            scale_residual: bool=False,
            scale_skip_connect: bool=False, ):
        """Initialize WaveNet module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            base_dilation (int): Base dilation factor.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for local conditioning feature.
            global_channels (int): Number of channels for global conditioning feature.
            dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_first_conv (bool): Whether to use the first conv layers.
            use_last_conv (bool): Whether to use the last conv layers.
            scale_residual (bool): Whether to scale the residual outputs.
            scale_skip_connect (bool): Whether to scale the skip connection outputs.

        """
        super().__init__()
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.base_dilation = base_dilation
        self.use_first_conv = use_first_conv
        self.use_last_conv = use_last_conv
        self.scale_skip_connect = scale_skip_connect

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        if self.use_first_conv:
            self.first_conv = nn.Conv1D(
                in_channels, residual_channels, kernel_size=1, bias_attr=True)

        # define residual blocks; dilation restarts from 1 at each stack
        self.conv_layers = nn.LayerList()
        for layer in range(layers):
            dilation = base_dilation**(layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                global_channels=global_channels,
                dilation=dilation,
                dropout_rate=dropout_rate,
                bias=bias,
                scale_residual=scale_residual, )
            self.conv_layers.append(conv)

        # define output layers
        if self.use_last_conv:
            self.last_conv = nn.Sequential(
                nn.ReLU(),
                nn.Conv1D(
                    skip_channels, skip_channels, kernel_size=1,
                    bias_attr=True),
                nn.ReLU(),
                nn.Conv1D(
                    skip_channels, out_channels, kernel_size=1, bias_attr=True),
            )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: Optional[paddle.Tensor]=None,
            c: Optional[paddle.Tensor]=None,
            g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
                (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).

        Returns:
            Tensor: Output tensor (B, out_channels, T) if use_last_conv else
                (B, residual_channels, T).

        """
        # encode to hidden representation
        if self.use_first_conv:
            x = self.first_conv(x)

        # residual block: accumulate skip outputs across all layers
        skips = 0.0
        for f in self.conv_layers:
            x, h = f(x, x_mask=x_mask, c=c, g=g)
            skips = skips + h
        x = skips
        if self.scale_skip_connect:
            # keep the skip sum variance independent of the layer count
            x = x * math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        if self.use_last_conv:
            x = self.last_conv(x)

        return x

    def apply_weight_norm(self):
        """Recursively apply weight normalization to all 1D/2D conv sublayers."""

        def _apply_weight_norm(layer):
            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
                nn.utils.weight_norm(layer)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Recursively remove weight normalization from all sublayers.

        Layers without weight norm raise ValueError, which is ignored.
        """

        def _remove_weight_norm(layer):
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                pass

        self.apply(_remove_weight_norm)
# loss for VITS
class KLDivergenceLoss(nn.Layer):
    """KL divergence loss."""

    def forward(
            self,
            z_p: paddle.Tensor,
            logs_q: paddle.Tensor,
            m_p: paddle.Tensor,
            logs_p: paddle.Tensor,
            z_mask: paddle.Tensor, ) -> paddle.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        # compute in float32 for numerical stability
        z_p = paddle.cast(z_p, 'float32')
        logs_q = paddle.cast(logs_q, 'float32')
        m_p = paddle.cast(m_p, 'float32')
        logs_p = paddle.cast(logs_p, 'float32')
        z_mask = paddle.cast(z_mask, 'float32')
        # KL between diagonal Gaussians, masked then averaged over valid frames
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p)
        kl = paddle.sum(kl * z_mask)
        loss = kl / paddle.sum(z_mask)

        return loss


# for VITS
def get_random_segments(
        x: paddle.Tensor,
        x_lengths: paddle.Tensor,
        segment_size: int, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).
    """
    b, c, t = paddle.shape(x)
    # draw a start index per item, clamped so the segment fits in the
    # valid (unpadded) part of the sequence
    max_start_idx = x_lengths - segment_size
    start_idxs = paddle.cast(paddle.rand([b]) * max_start_idx, 'int64')
    segments = get_segments(x, start_idxs, segment_size)

    return segments, start_idxs


def get_segments(
        x: paddle.Tensor,
        start_idxs: paddle.Tensor,
        segment_size: int, ) -> paddle.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
    """
    b, c, t = paddle.shape(x)
    segments = paddle.zeros([b, c, segment_size], dtype=x.dtype)
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx:start_idx + segment_size]
    return segments


# see https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
def paddle_gather(x, dim, index):
    """Gather values along an axis, equivalent to `torch.gather`.

    Args:
        x (Tensor): Source tensor.
        dim (int): Axis along which to gather (may be negative).
        index (Tensor): Index tensor with the same rank as `x`.

    Returns:
        Tensor: Gathered tensor with the same shape as `index`.
    """
    index_shape = index.shape
    index_flatten = index.flatten()
    if dim < 0:
        dim = len(x.shape) + dim
    # build an N-d index: the gathered axis uses `index`, every other axis
    # uses its own coordinate grid, then look everything up with gather_nd
    nd_index = []
    for k in range(len(x.shape)):
        if k == dim:
            nd_index.append(index_flatten)
        else:
            reshape_shape = [1] * len(x.shape)
            reshape_shape[k] = x.shape[k]
            x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
            x_arange = x_arange.reshape(reshape_shape)
            dim_index = paddle.expand(x_arange, index_shape).flatten()
            nd_index.append(dim_index)
    ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
    paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
    return paddle_out
""" - return MultiSequential(*[fn(n) for n in range(N)]) + return MultiSequential(* [fn(n) for n in range(N)]) diff --git a/paddlespeech/t2s/training/optimizer.py b/paddlespeech/t2s/training/optimizer.py index 64274d538..3342cae53 100644 --- a/paddlespeech/t2s/training/optimizer.py +++ b/paddlespeech/t2s/training/optimizer.py @@ -14,6 +14,14 @@ import paddle from paddle import nn +scheduler_classes = dict( + ReduceOnPlateau=paddle.optimizer.lr.ReduceOnPlateau, + lambda_decay=paddle.optimizer.lr.LambdaDecay, + step_decay=paddle.optimizer.lr.StepDecay, + multistep_decay=paddle.optimizer.lr.MultiStepDecay, + exponential_decay=paddle.optimizer.lr.ExponentialDecay, + CosineAnnealingDecay=paddle.optimizer.lr.CosineAnnealingDecay, ) + optim_classes = dict( adadelta=paddle.optimizer.Adadelta, adagrad=paddle.optimizer.Adagrad, diff --git a/paddlespeech/t2s/utils/profile.py b/paddlespeech/t2s/utils/profile.py deleted file mode 100644 index 5f9b49526..000000000 --- a/paddlespeech/t2s/utils/profile.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from contextlib import contextmanager - -import paddle -from paddle.framework import core -from paddle.framework import CUDAPlace - - -def synchronize(): - """Trigger cuda synchronization for better timing.""" - place = paddle.fluid.framework._current_expected_place() - if isinstance(place, CUDAPlace): - paddle.fluid.core._cuda_synchronize(place) - - -@contextmanager -def nvtx_span(name): - try: - core.nvprof_nvtx_push(name) - yield - finally: - core.nvprof_nvtx_pop() diff --git a/paddlespeech/t2s/utils/timeline.py b/paddlespeech/t2s/utils/timeline.py deleted file mode 100644 index 0a5509dbe..000000000 --- a/paddlespeech/t2s/utils/timeline.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import json - -import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2 -import six - -parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - '--profile_path', - type=str, - default='', - help='Input profile file name. 
If there are multiple file, the format ' - 'should be trainer1=file1,trainer2=file2,ps=file3') -parser.add_argument( - '--timeline_path', type=str, default='', help='Output timeline file name.') -args = parser.parse_args() - - -class _ChromeTraceFormatter(object): - def __init__(self): - self._events = [] - self._metadata = [] - - def _create_event(self, ph, category, name, pid, tid, timestamp): - """Creates a new Chrome Trace event. - - For details of the file format, see: - https://github.com/catapult-project/catapult/blob/master/tracing/README.md - - Args: - ph: The type of event - usually a single character. - category: The event category as a string. - name: The event name as a string. - pid: Identifier of the process generating this event as an integer. - tid: Identifier of the thread generating this event as an integer. - timestamp: The timestamp of this event as a long integer. - - Returns: - A JSON compatible event object. - """ - event = {} - event['ph'] = ph - event['cat'] = category - event['name'] = name.replace("ParallelExecutor::Run/", "") - event['pid'] = pid - event['tid'] = tid - event['ts'] = timestamp - return event - - def emit_pid(self, name, pid): - """Adds a process metadata event to the trace. - - Args: - name: The process name as a string. - pid: Identifier of the process as an integer. - """ - event = {} - event['name'] = 'process_name' - event['ph'] = 'M' - event['pid'] = pid - event['args'] = {'name': name} - self._metadata.append(event) - - def emit_region(self, timestamp, duration, pid, tid, category, name, args): - """Adds a region event to the trace. - - Args: - timestamp: The start timestamp of this region as a long integer. - duration: The duration of this region as a long integer. - pid: Identifier of the process generating this event as an integer. - tid: Identifier of the thread generating this event as an integer. - category: The event category as a string. - name: The event name as a string. 
- args: A JSON-compatible dictionary of event arguments. - """ - event = self._create_event('X', category, name, pid, tid, timestamp) - event['dur'] = duration - event['args'] = args - self._events.append(event) - - def emit_counter(self, category, name, pid, timestamp, counter, value): - """Emits a record for a single counter. - - Args: - category: The event category as string - name: The event name as string - pid: Identifier of the process generating this event as integer - timestamp: The timestamps of this event as long integer - counter: Name of the counter as string - value: Value of the counter as integer - tid: Thread id of the allocation as integer - """ - event = self._create_event('C', category, name, pid, 0, timestamp) - event['args'] = {counter: value} - self._events.append(event) - - def format_to_string(self, pretty=False): - """Formats the chrome trace to a string. - - Args: - pretty: (Optional.) If True, produce human-readable JSON output. - - Returns: - A JSON-formatted string in Chrome Trace format. - """ - trace = {} - trace['traceEvents'] = self._metadata + self._events - if pretty: - return json.dumps(trace, indent=4, separators=(',', ': ')) - else: - return json.dumps(trace, separators=(',', ':')) - - -class Timeline(object): - def __init__(self, profile_dict): - self._profile_dict = profile_dict - self._pid = 0 - self._devices = dict() - self._mem_devices = dict() - self._chrome_trace = _ChromeTraceFormatter() - - def _allocate_pid(self): - cur_pid = self._pid - self._pid += 1 - return cur_pid - - def _allocate_pids(self): - for k, profile_pb in six.iteritems(self._profile_dict): - for event in profile_pb.events: - if event.type == profiler_pb2.Event.CPU: - if (k, event.device_id, "CPU") not in self._devices: - pid = self._allocate_pid() - self._devices[(k, event.device_id, "CPU")] = pid - # -1 device id represents CUDA API(RunTime) call.(e.g. 
cudaLaunch, cudaMemcpy) - if event.device_id == -1: - self._chrome_trace.emit_pid("%s:cuda_api" % k, pid) - else: - self._chrome_trace.emit_pid( - "%s:cpu:block:%d" % (k, event.device_id), pid) - elif event.type == profiler_pb2.Event.GPUKernel: - if (k, event.device_id, "GPUKernel") not in self._devices: - pid = self._allocate_pid() - self._devices[(k, event.device_id, "GPUKernel")] = pid - self._chrome_trace.emit_pid("%s:gpu:%d" % - (k, event.device_id), pid) - if not hasattr(profile_pb, "mem_events"): - continue - for mevent in profile_pb.mem_events: - if mevent.place == profiler_pb2.MemEvent.CUDAPlace: - if (k, mevent.device_id, "GPU") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, mevent.device_id, "GPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:gpu:%d" % (k, mevent.device_id), - pid) - elif mevent.place == profiler_pb2.MemEvent.CPUPlace: - if (k, mevent.device_id, "CPU") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, mevent.device_id, "CPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:cpu:%d" % (k, mevent.device_id), - pid) - elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace: - if (k, mevent.device_id, - "CUDAPinnedPlace") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, mevent.device_id, - "CUDAPinnedPlace")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:cudapinnedplace:%d" % - (k, mevent.device_id), pid) - elif mevent.place == profiler_pb2.MemEvent.NPUPlace: - if (k, mevent.device_id, "NPU") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, mevent.device_id, "NPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:npu:%d" % (k, mevent.device_id), - pid) - if (k, 0, "CPU") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, 0, "CPU")] = pid - self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" % - (k, 0), pid) - if (k, 0, "GPU") not in 
self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, 0, "GPU")] = pid - self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" % - (k, 0), pid) - if (k, 0, "CUDAPinnedPlace") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid) - if (k, 0, "NPU") not in self._mem_devices: - pid = self._allocate_pid() - self._mem_devices[(k, 0, "NPU")] = pid - self._chrome_trace.emit_pid("memory usage on %s:npu:%d" % - (k, 0), pid) - - def _allocate_events(self): - for k, profile_pb in six.iteritems(self._profile_dict): - for event in profile_pb.events: - if event.type == profiler_pb2.Event.CPU: - type = "CPU" - elif event.type == profiler_pb2.Event.GPUKernel: - type = "GPUKernel" - pid = self._devices[(k, event.device_id, type)] - args = {'name': event.name} - if event.memcopy.bytes > 0: - args['mem_bytes'] = event.memcopy.bytes - if hasattr(event, "detail_info") and event.detail_info: - args['detail_info'] = event.detail_info - # TODO(panyx0718): Chrome tracing only handles ms. However, some - # ops takes micro-seconds. Hence, we keep the ns here. 
- self._chrome_trace.emit_region( - event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid, - event.sub_device_id, 'Op', event.name, args) - - def _allocate_memory_event(self): - if not hasattr(profiler_pb2, "MemEvent"): - return - place_to_str = { - profiler_pb2.MemEvent.CPUPlace: "CPU", - profiler_pb2.MemEvent.CUDAPlace: "GPU", - profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace", - profiler_pb2.MemEvent.NPUPlace: "NPU" - } - for k, profile_pb in six.iteritems(self._profile_dict): - mem_list = [] - end_profiler = 0 - for mevent in profile_pb.mem_events: - crt_info = dict() - crt_info['time'] = mevent.start_ns - crt_info['size'] = mevent.bytes - if mevent.place in place_to_str: - place = place_to_str[mevent.place] - else: - place = "UnDefine" - crt_info['place'] = place - pid = self._mem_devices[(k, mevent.device_id, place)] - crt_info['pid'] = pid - crt_info['thread_id'] = mevent.thread_id - crt_info['device_id'] = mevent.device_id - mem_list.append(crt_info) - crt_info = dict() - crt_info['place'] = place - crt_info['pid'] = pid - crt_info['thread_id'] = mevent.thread_id - crt_info['device_id'] = mevent.device_id - crt_info['time'] = mevent.end_ns - crt_info['size'] = -mevent.bytes - mem_list.append(crt_info) - end_profiler = max(end_profiler, crt_info['time']) - mem_list.sort(key=lambda tmp: (tmp.get('time', 0))) - i = 0 - total_size = 0 - while i < len(mem_list): - total_size += mem_list[i]['size'] - while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[ - i + 1]['time']: - total_size += mem_list[i + 1]['size'] - i += 1 - - self._chrome_trace.emit_counter( - "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'], - 0, total_size) - i += 1 - - def generate_chrome_trace(self): - self._allocate_pids() - self._allocate_events() - self._allocate_memory_event() - return self._chrome_trace.format_to_string() - - -profile_path = '/tmp/profile' -if args.profile_path: - profile_path = args.profile_path -timeline_path = '/tmp/timeline' -if 
args.timeline_path: - timeline_path = args.timeline_path - -profile_paths = profile_path.split(',') -profile_dict = dict() -if len(profile_paths) == 1: - with open(profile_path, 'rb') as f: - profile_s = f.read() - profile_pb = profiler_pb2.Profile() - profile_pb.ParseFromString(profile_s) - profile_dict['trainer'] = profile_pb -else: - for profile_path in profile_paths: - k, v = profile_path.split('=') - with open(v, 'rb') as f: - profile_s = f.read() - profile_pb = profiler_pb2.Profile() - profile_pb.ParseFromString(profile_s) - profile_dict[k] = profile_pb - -tl = Timeline(profile_dict) -with open(timeline_path, 'w') as f: - f.write(tl.generate_chrome_trace()) diff --git a/examples/other/1xt2x/src_deepspeech2x/models/__init__.py b/paddlespeech/utils/__init__.py similarity index 100% rename from examples/other/1xt2x/src_deepspeech2x/models/__init__.py rename to paddlespeech/utils/__init__.py diff --git a/paddlespeech/utils/dynamic_import.py b/paddlespeech/utils/dynamic_import.py new file mode 100644 index 000000000..99f93356f --- /dev/null +++ b/paddlespeech/utils/dynamic_import.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from espnet(https://github.com/espnet/espnet) +import importlib + +__all__ = ["dynamic_import"] + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'paddlespeech.s2t.models.u2:U2Model' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError( + "import_path should be one of {} or " + 'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : ' + "{}".format(set(alias), import_path)) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname) diff --git a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py index e8d91bf3a..cd4538bb5 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py @@ -16,10 +16,10 @@ import os import time import paddle -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode +from paddlespeech.audio.backends import load as load_audio +from paddlespeech.audio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py index f15dbf9b7..6c87dbe7b 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/test.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py @@ -18,10 +18,10 @@ import numpy as np import paddle from paddle.io import BatchSampler from paddle.io import DataLoader -from paddleaudio.metric import compute_eer from tqdm import tqdm from yacs.config import CfgNode 
+from paddlespeech.audio.metric import compute_eer from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.batch import batch_feature_normalize from paddlespeech.vector.io.dataset import CSVDataset diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py index bf014045d..961b75e29 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/train.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py @@ -20,9 +20,9 @@ import paddle from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode +from paddlespeech.audio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import waveform_augment diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py index 1b514f3d6..245b29592 100644 --- a/paddlespeech/vector/io/dataset.py +++ b/paddlespeech/vector/io/dataset.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from dataclasses import fields from paddle.io import Dataset -from paddleaudio import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram +from paddlespeech.audio import load as load_audio +from paddlespeech.audio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py index bf04e1132..12e845771 100644 --- a/paddlespeech/vector/io/dataset_from_json.py +++ b/paddlespeech/vector/io/dataset_from_json.py @@ -16,9 +16,10 @@ from dataclasses import dataclass from dataclasses import fields from paddle.io import Dataset -from paddleaudio import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram -from 
paddleaudio.compliance.librosa import mfcc + +from paddlespeech.audio import load as load_audio +from paddlespeech.audio.compliance.librosa import melspectrogram +from paddlespeech.audio.compliance.librosa import mfcc @dataclass diff --git a/setup.py b/setup.py index ad353d42b..679549b4d 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ from setuptools import find_packages from setuptools import setup from setuptools.command.develop import develop from setuptools.command.install import install +from setuptools.command.test import test HERE = Path(os.path.abspath(os.path.dirname(__file__))) @@ -31,42 +32,13 @@ VERSION = '0.0.0' COMMITID = 'none' base = [ - "editdistance", - "g2p_en", - "g2pM", - "h5py", - "inflect", - "jieba", - "jsonlines", - "kaldiio", - "librosa==0.8.1", - "loguru", - "matplotlib", - "nara_wpe", - "onnxruntime", - "pandas", - "paddleaudio", - "paddlenlp", - "paddlespeech_feat", - "praatio==5.0.0", - "pypinyin", - "pypinyin-dict", - "python-dateutil", - "pyworld", - "resampy==0.2.2", - "sacrebleu", - "scipy", - "sentencepiece~=0.1.96", - "soundfile~=0.10", - "textgrid", - "timer", - "tqdm", - "typeguard", - "visualdl", - "webrtcvad", - "yacs~=0.1.8", - "prettytable", - "zhon", + "editdistance", "g2p_en", "g2pM", "h5py", "inflect", "jieba", "jsonlines", + "kaldiio", "librosa==0.8.1", "loguru", "matplotlib", "nara_wpe", + "onnxruntime", "pandas", "paddlenlp", "paddlespeech_feat", "praatio==5.0.0", + "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2", + "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", + "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad", + "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8' ] server = [ @@ -98,7 +70,6 @@ requirements = { } - def check_call(cmd: str, shell=False, executable=None): try: sp.check_call( @@ -112,12 +83,13 @@ def check_call(cmd: str, shell=False, executable=None): file=sys.stderr) raise e + def check_output(cmd: str, 
shell=False): try: out_bytes = sp.check_output(cmd.split()) except sp.CalledProcessError as e: - out_bytes = e.output # Output generated before error - code = e.returncode # Return code + out_bytes = e.output # Output generated before error + code = e.returncode # Return code print( f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:", out_bytes, @@ -146,6 +118,7 @@ def _remove(files: str): for f in files: f.unlink() + ################################# Install ################################## @@ -176,7 +149,19 @@ class InstallCommand(install): install.run(self) - # cmd: python setup.py upload +class TestCommand(test): + def finalize_options(self): + test.finalize_options(self) + self.test_args = [] + self.test_suite = True + + def run_tests(self): + # Run nose ensuring that argv simulates running nosetests directly + import nose + nose.run_exit(argv=['nosetests', '-w', 'tests']) + + +# cmd: python setup.py upload class UploadCommand(Command): description = "Build and publish the package." user_options = [] @@ -278,11 +263,13 @@ setup_info = dict( "sphinx", "sphinx-rtd-theme", "numpydoc", "myst_parser", "recommonmark>=0.5.0", "sphinx-markdown-tables", "sphinx-autobuild" ], + 'test': ['nose', 'torchaudio==0.10.2'], }, cmdclass={ 'develop': DevelopCommand, 'install': InstallCommand, 'upload': UploadCommand, + 'test': TestCommand, }, # Package info @@ -308,6 +295,5 @@ setup_info = dict( ] }) - with version_info(): setup(**setup_info) diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 98d9e6374..4b5838e5c 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -142,4 +142,3 @@ set(DEPS ${DEPS} set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) add_subdirectory(speechx) -add_subdirectory(examples) diff --git a/speechx/README.md b/speechx/README.md index f75d8ac4e..cd1cd62c1 100644 --- a/speechx/README.md +++ b/speechx/README.md @@ -44,13 +44,13 @@ More details please see `README.md` under `examples`. 
> If using docker please check `--privileged` is set when `docker run`. * Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up` -``` +```bash apt-get install libc6-dbg ``` * Install -``` +```bash pushd tools ./setup_valgrind.sh popd @@ -59,4 +59,4 @@ popd ## TODO ### Deepspeech2 with linear feature -* DecibelNormalizer: there is a little bit difference between offline and online db norm. The computation of online db norm read feature chunk by chunk, which causes the feature size is different with offline db norm. In normalizer.cc:73, the samples.size() is different, which causes the difference of result. +* DecibelNormalizer: there is a small difference between the offline and online db norm. The computation of online db norm reads features chunk by chunk, which causes the feature size to be different different with offline db norm. In `normalizer.cc:73`, the `samples.size()` is different, which causes the different result. diff --git a/speechx/examples/CMakeLists.txt b/speechx/examples/CMakeLists.txt deleted file mode 100644 index 3c274a20a..000000000 --- a/speechx/examples/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -add_subdirectory(ds2_ol) -add_subdirectory(dev) \ No newline at end of file diff --git a/speechx/examples/README.md b/speechx/examples/README.md index b18c88e04..f7f6f9ac0 100644 --- a/speechx/examples/README.md +++ b/speechx/examples/README.md @@ -22,14 +22,6 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host ## For Developer -> Warning: Only for developer, make sure you know what's it. +> Reminder: Only for developer, make sure you know what's it. -* dev - for speechx developer, using for test. - -## Build WFST - -> Warning: Using below example when you know what's it. - -* text_lm - process text for build lm -* ngram - using to build NGram ARPA lm. -* wfst - build wfst for TLG. 
+* codelab - for speechx developer, using for test. diff --git a/speechx/examples/codelab/README.md b/speechx/examples/codelab/README.md new file mode 100644 index 000000000..f89184de9 --- /dev/null +++ b/speechx/examples/codelab/README.md @@ -0,0 +1,8 @@ +# Codelab + +## introduction + +> The below is for developing and offline testing. Do not run it only if you know what it is. +* nnet +* feat +* decoder diff --git a/speechx/examples/ds2_ol/decoder/.gitignore b/speechx/examples/codelab/decoder/.gitignore similarity index 100% rename from speechx/examples/ds2_ol/decoder/.gitignore rename to speechx/examples/codelab/decoder/.gitignore diff --git a/speechx/examples/ds2_ol/decoder/README.md b/speechx/examples/codelab/decoder/README.md similarity index 100% rename from speechx/examples/ds2_ol/decoder/README.md rename to speechx/examples/codelab/decoder/README.md diff --git a/speechx/examples/codelab/decoder/path.sh b/speechx/examples/codelab/decoder/path.sh new file mode 100644 index 000000000..9d2291743 --- /dev/null +++ b/speechx/examples/codelab/decoder/path.sh @@ -0,0 +1,14 @@ +# This contains the locations of binarys build required for running the examples. + +SPEECHX_ROOT=$PWD/../../../ +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx + +SPEECHX_TOOLS=$SPEECHX_ROOT/tools +TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin + +[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. 
please ensure that the project build successfully"; } + +export LC_AL=C + +SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/decoder/run.sh b/speechx/examples/codelab/decoder/run.sh similarity index 94% rename from speechx/examples/ds2_ol/decoder/run.sh rename to speechx/examples/codelab/decoder/run.sh index 40501eb41..a911eb033 100755 --- a/speechx/examples/ds2_ol/decoder/run.sh +++ b/speechx/examples/codelab/decoder/run.sh @@ -54,7 +54,7 @@ cmvn=$exp_dir/cmvn.ark export GLOG_logtostderr=1 # dump json cmvn to kaldi -cmvn-json2kaldi \ +cmvn_json2kaldi_main \ --json_file $ckpt_dir/data/mean_std.json \ --cmvn_write_path $cmvn \ --binary=false @@ -62,17 +62,17 @@ echo "convert json cmvn to kaldi ark." # generate linear feature as streaming -linear-spectrogram-wo-db-norm-ol \ +compute_linear_spectrogram_main \ --wav_rspecifier=scp:$data/wav.scp \ --feature_wspecifier=ark,t:$feat_wspecifier \ --cmvn_file=$cmvn echo "compute linear spectrogram feature." 
# run ctc beam search decoder as streaming -ctc-prefix-beam-search-decoder-ol \ +ctc_prefix_beam_search_decoder_main \ --result_wspecifier=ark,t:$exp_dir/result.txt \ --feature_rspecifier=ark:$feat_wspecifier \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams \ --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ No newline at end of file + --lm_path=$lm diff --git a/speechx/examples/ds2_ol/decoder/valgrind.sh b/speechx/examples/codelab/decoder/valgrind.sh similarity index 100% rename from speechx/examples/ds2_ol/decoder/valgrind.sh rename to speechx/examples/codelab/decoder/valgrind.sh diff --git a/speechx/examples/ds2_ol/feat/README.md b/speechx/examples/codelab/feat/README.md similarity index 58% rename from speechx/examples/ds2_ol/feat/README.md rename to speechx/examples/codelab/feat/README.md index 89cb79eca..e59e02bf9 100644 --- a/speechx/examples/ds2_ol/feat/README.md +++ b/speechx/examples/codelab/feat/README.md @@ -2,6 +2,6 @@ ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner. -* linear_spectrogram_without_db_norm_main.cc +* compute_linear_spectrogram_main.cc -compute linear spectrogram w/o db norm in streaming manner. +compute linear spectrogram without db norm in streaming manner. diff --git a/speechx/examples/dev/glog/path.sh b/speechx/examples/codelab/feat/path.sh similarity index 82% rename from speechx/examples/dev/glog/path.sh rename to speechx/examples/codelab/feat/path.sh index 1a96a861a..3b89d01e9 100644 --- a/speechx/examples/dev/glog/path.sh +++ b/speechx/examples/codelab/feat/path.sh @@ -1,15 +1,14 @@ # This contains the locations of binarys build required for running the examples. 
SPEECHX_ROOT=$PWD/../../../ +SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples [ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } -SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN - export LC_AL=C + +SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/feat/run.sh b/speechx/examples/codelab/feat/run.sh similarity index 95% rename from speechx/examples/ds2_ol/feat/run.sh rename to speechx/examples/codelab/feat/run.sh index 757779275..1fa37f981 100755 --- a/speechx/examples/ds2_ol/feat/run.sh +++ b/speechx/examples/codelab/feat/run.sh @@ -41,14 +41,14 @@ mkdir -p $exp_dir # 3. run feat export GLOG_logtostderr=1 -cmvn-json2kaldi \ +cmvn_json2kaldi_main \ --json_file $model_dir/data/mean_std.json \ --cmvn_write_path $exp_dir/cmvn.ark \ --binary=false echo "convert json cmvn to kaldi ark." 
-linear-spectrogram-wo-db-norm-ol \ +compute_linear_spectrogram_main \ --wav_rspecifier=scp:$data_dir/wav.scp \ --feature_wspecifier=ark,t:$exp_dir/feats.ark \ --cmvn_file=$exp_dir/cmvn.ark diff --git a/speechx/examples/ds2_ol/feat/valgrind.sh b/speechx/examples/codelab/feat/valgrind.sh similarity index 93% rename from speechx/examples/ds2_ol/feat/valgrind.sh rename to speechx/examples/codelab/feat/valgrind.sh index f8aab63f8..ea50fdc23 100755 --- a/speechx/examples/ds2_ol/feat/valgrind.sh +++ b/speechx/examples/codelab/feat/valgrind.sh @@ -17,7 +17,7 @@ feat_wspecifier=./feats.ark cmvn=./cmvn.ark valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \ - linear_spectrogram_main \ + compute_linear_spectrogram_main \ --wav_rspecifier=scp:$model_dir/wav.scp \ --feature_wspecifier=ark,t:$feat_wspecifier \ --cmvn_write_path=$cmvn diff --git a/speechx/examples/ds2_ol/nnet/.gitignore b/speechx/examples/codelab/nnet/.gitignore similarity index 100% rename from speechx/examples/ds2_ol/nnet/.gitignore rename to speechx/examples/codelab/nnet/.gitignore diff --git a/speechx/examples/ds2_ol/nnet/README.md b/speechx/examples/codelab/nnet/README.md similarity index 100% rename from speechx/examples/ds2_ol/nnet/README.md rename to speechx/examples/codelab/nnet/README.md diff --git a/speechx/examples/ds2_ol/feat/path.sh b/speechx/examples/codelab/nnet/path.sh similarity index 81% rename from speechx/examples/ds2_ol/feat/path.sh rename to speechx/examples/codelab/nnet/path.sh index ad2b6a4e9..7d395d648 100644 --- a/speechx/examples/ds2_ol/feat/path.sh +++ b/speechx/examples/codelab/nnet/path.sh @@ -1,7 +1,7 @@ # This contains the locations of binarys build required for running the examples. 
SPEECHX_ROOT=$PWD/../../../ -SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin @@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat +SPEECHX_BIN=$SPEECHX_BUILD/codelab/nnet export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/nnet/run.sh b/speechx/examples/codelab/nnet/run.sh similarity index 75% rename from speechx/examples/ds2_ol/nnet/run.sh rename to speechx/examples/codelab/nnet/run.sh index 10029f7e8..842499ba2 100755 --- a/speechx/examples/ds2_ol/nnet/run.sh +++ b/speechx/examples/codelab/nnet/run.sh @@ -20,19 +20,10 @@ if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; popd fi -# produce wav scp -if [ ! -f data/wav.scp ]; then - mkdir -p data - pushd data - wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav - echo "utt1 " $PWD/zh.wav > wav.scp - popd -fi - ckpt_dir=./data/model model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -ds2-model-ol-test \ +ds2_model_test_main \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams diff --git a/speechx/examples/ds2_ol/nnet/valgrind.sh b/speechx/examples/codelab/nnet/valgrind.sh similarity index 71% rename from speechx/examples/ds2_ol/nnet/valgrind.sh rename to speechx/examples/codelab/nnet/valgrind.sh index 2a08c6082..a5aab6637 100755 --- a/speechx/examples/ds2_ol/nnet/valgrind.sh +++ b/speechx/examples/codelab/nnet/valgrind.sh @@ -12,9 +12,10 @@ if [ ! 
-d ${SPEECHX_TOOLS}/valgrind/install ]; then exit 1 fi -model_dir=../paddle_asr_model +ckpt_dir=./data/model +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \ - pp-model-test \ + ds2_model_test_main \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdparams \ No newline at end of file + --param_path=$model_dir/avg_1.jit.pdparams diff --git a/speechx/examples/custom_asr/README.md b/speechx/examples/custom_asr/README.md new file mode 100644 index 000000000..5ffa21b50 --- /dev/null +++ b/speechx/examples/custom_asr/README.md @@ -0,0 +1,32 @@ +# customized Auto Speech Recognition + +## introduction +These scripts are tutorials to show you how build your own decoding graph. + +eg: +* G with slot: 打车到 "address_slot"。 +![](https://ai-studio-static-online.cdn.bcebos.com/28d9ef132a7f47a895a65ae9e5c4f55b8f472c9f3dd24be8a2e66e0b88b173a4) + +* this is address slot wfst, you can add the address which want to recognize. +![](https://ai-studio-static-online.cdn.bcebos.com/47c89100ef8c465bac733605ffc53d76abefba33d62f4d818d351f8cea3c8fe2) + +* after replace operation, G = fstreplace(G_with_slot, address_slot), we will get the customized graph. +![](https://ai-studio-static-online.cdn.bcebos.com/60a3095293044f10b73039ab10c7950d139a6717580a44a3ba878c6e74de402b) + +These operations are in the scripts, please check out. we will lanuch more detail scripts. 
+ +## How to run + +``` +bash run.sh +``` + +## Results + +### CTC WFST + +``` +Overall -> 1.23 % N=1134 C=1126 S=6 D=2 I=6 +Mandarin -> 1.24 % N=1132 C=1124 S=6 D=2 I=6 +English -> 0.00 % N=2 C=2 S=0 D=0 I=0 +``` diff --git a/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh b/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh new file mode 100755 index 000000000..8411f7ed6 --- /dev/null +++ b/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# Copyright 2015 Yajie Miao (Carnegie Mellon University) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This script compiles the lexicon and CTC tokens into FSTs. FST compiling slightly differs between the +# phoneme and character-based lexicons. +set -eo pipefail +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "usage: utils/fst/compile_lexicon_token_fst.sh " + echo "e.g.: utils/fst/compile_lexicon_token_fst.sh data/local/dict data/local/lang_tmp data/lang" + echo " should contain the following files:" + echo "lexicon.txt lexicon_numbers.txt units.txt" + echo "options: " + exit 1; +fi + +srcdir=$1 +tmpdir=$2 +dir=$3 +mkdir -p $dir $tmpdir + +[ -f path.sh ] && . ./path.sh + +cp $srcdir/units.txt $dir + +# Add probabilities to lexicon entries. There is in fact no point of doing this here since all the entries have 1.0. 
+# But utils/make_lexicon_fst.pl requires a probabilistic version, so we just leave it as it is. +perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $tmpdir/lexiconp.txt || exit 1; + +# Add disambiguation symbols to the lexicon. This is necessary for determinizing the composition of L.fst and G.fst. +# Without these symbols, determinization will fail. +# default first disambiguation is #1 +ndisambig=`utils/fst/add_lex_disambig.pl $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt` +# add #0 (#0 reserved for symbol in grammar). +ndisambig=$[$ndisambig+1]; + +( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) > $tmpdir/disambig.list + +# Get the full list of CTC tokens used in FST. These tokens include , the blank , +# the actual model unit, and the disambiguation symbols. +cat $srcdir/units.txt | awk '{print $1}' > $tmpdir/units.list +(echo '';) | cat - $tmpdir/units.list $tmpdir/disambig.list | awk '{print $1 " " (NR-1)}' > $dir/tokens.txt + +# ctc_token_fst_corrected is too big and too slow for character based chinese modeling, +# so here just use simple ctc_token_fst +utils/fst/ctc_token_fst.py --token_file $dir/tokens.txt | \ + fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/tokens.txt --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $dir/T.fst || exit 1; + +# Encode the words with indices. Will be used in lexicon and language model FST compiling. +cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | awk ' + BEGIN { + print " 0"; + } + { + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + printf("ROOT %d\n", NR+4); + }' > $dir/words.txt || exit 1; + +# Now compile the lexicon FST. Depending on the size of your lexicon, it may take some time. 
+token_disambig_symbol=`grep \#0 $dir/tokens.txt | awk '{print $2}'` +word_disambig_symbol=`grep \#0 $dir/words.txt | awk '{print $2}'` + +utils/fst/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt 0 "sil" '#'$ndisambig | \ + fstcompile --isymbols=$dir/tokens.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstaddselfloops "echo $token_disambig_symbol |" "echo $word_disambig_symbol |" | \ + fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; + +echo "Lexicon and Token FSTs compiling succeeded" diff --git a/speechx/examples/custom_asr/local/mk_slot_graph.sh b/speechx/examples/custom_asr/local/mk_slot_graph.sh new file mode 100755 index 000000000..8298a5d09 --- /dev/null +++ b/speechx/examples/custom_asr/local/mk_slot_graph.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +graph_slot=$1 +dir=$2 + +[ -f path.sh ] && . 
./path.sh + +sym=$dir/../lang/words.txt +cat > $dir/address_slot.txt < +0 5 上海 上海 +0 5 北京 北京 +0 5 合肥 合肥 +5 1 南站 南站 +0 6 立水 立水 +6 1 桥 桥 +0 7 青岛 青岛 +7 1 站 站 +1 +EOF + +fstcompile --isymbols=$sym --osymbols=$sym $dir/address_slot.txt $dir/address_slot.fst +fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/time_slot.txt $dir/time_slot.fst +fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/date_slot.txt $dir/date_slot.fst +fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/money_slot.txt $dir/money_slot.fst +fstcompile --isymbols=$sym --osymbols=$sym $graph_slot/year_slot.txt $dir/year_slot.fst diff --git a/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh b/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh new file mode 100755 index 000000000..a5569f400 --- /dev/null +++ b/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +lm=$1 +lang=$2 +tgt_lang=$3 + +unset GREP_OPTIONS + +sym=$lang/words.txt +arpa_lm=$lm/lm.arpa +# Compose the language model to FST +cat $arpa_lm | \ + grep -v ' ' | \ + grep -v ' ' | \ + grep -v ' ' | \ + grep -v -i '' | \ + grep -v -i '' | \ + arpa2fst --read-symbol-table=$sym --keep-symbols=true - | fstprint | \ + utils/fst/eps2disambig.pl | utils/fst/s2eps.pl | fstcompile --isymbols=$sym \ + --osymbols=$sym --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G_with_slot.fst + +root_label=`grep ROOT $sym | awk '{print $2}'` +address_slot_label=`grep \ $sym | awk '{print $2}'` +time_slot_label=`grep \ $sym | awk '{print $2}'` +date_slot_label=`grep \ $sym | awk '{print $2}'` +money_slot_label=`grep \ $sym | awk '{print $2}'` +year_slot_label=`grep \ $sym | awk '{print $2}'` + +fstisstochastic $tgt_lang/G_with_slot.fst + +fstreplace --epsilon_on_replace $tgt_lang/G_with_slot.fst \ + $root_label $tgt_lang/address_slot.fst $address_slot_label \ + $tgt_lang/date_slot.fst $date_slot_label \ + $tgt_lang/money_slot.fst $money_slot_label \ + $tgt_lang/time_slot.fst $time_slot_label \ + $tgt_lang/year_slot.fst $year_slot_label $tgt_lang/G.fst + +fstisstochastic $tgt_lang/G.fst + +# Compose the token, lexicon and language-model FST into the final decoding graph +fsttablecompose $lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \ + fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1; +fsttablecompose $lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1; +rm $tgt_lang/LG.fst + +echo "Composing decoding graph TLG.fst succeeded" \ No newline at end of file diff --git a/speechx/examples/custom_asr/local/train_lm_with_slot.sh b/speechx/examples/custom_asr/local/train_lm_with_slot.sh new file mode 100755 index 000000000..3f557ec39 --- /dev/null +++ 
b/speechx/examples/custom_asr/local/train_lm_with_slot.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# To be run from one directory above this script. +. ./path.sh +src=ds2_graph_with_slot +text=$src/train_text +lexicon=$src/local/dict/lexicon.txt + +dir=$src/local/lm +mkdir -p $dir + +for f in "$text" "$lexicon"; do + [ ! -f $x ] && echo "$0: No such file $f" && exit 1; +done + +# Check SRILM tools +if ! which ngram-count > /dev/null; then + pushd $MAIN_ROOT/tools + make srilm.done + popd +fi + +# This script takes no arguments. It assumes you have already run +# It takes as input the files +# data/local/lm/text +# data/local/dict/lexicon.txt + + +cleantext=$dir/text.no_oov + +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + > $cleantext || exit 1; + +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ + sort -nr > $dir/word.counts || exit 1; +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). 
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ + sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; + +# filter the words which are not in the text +cat $dir/unigram.counts | awk '$1>1{print $0}' | awk '{print $2}' | cat - <(echo ""; echo "" ) > $dir/wordlist + +# kaldi_lm results +mkdir -p $dir +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/train + +ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ + -map-unk "" -gt3max 0 -gt2max 0 -gt1max 0 -lm $dir/lm.arpa + +#ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ +# -map-unk "" -lm $dir/lm2.arpa \ No newline at end of file diff --git a/speechx/examples/custom_asr/path.sh b/speechx/examples/custom_asr/path.sh new file mode 100644 index 000000000..1907c79f9 --- /dev/null +++ b/speechx/examples/custom_asr/path.sh @@ -0,0 +1,17 @@ +# This contains the locations of binarys build required for running the examples. + +MAIN_ROOT=`realpath $PWD/../../../` +SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` +SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples + +export LC_AL=C + +# srilm +export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs +export SRILM=${MAIN_ROOT}/tools/srilm + +# kaldi lm +KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/ +OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src +export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin:$SPEECHX_EXAMPLES/ds2_ol/decoder diff --git a/speechx/examples/custom_asr/run.sh b/speechx/examples/custom_asr/run.sh new file mode 100644 index 000000000..ed67a52be --- /dev/null +++ b/speechx/examples/custom_asr/run.sh @@ -0,0 +1,87 @@ +#!/bin/bash +set +x +set -e + +export GLOG_logtostderr=1 + +. 
./path.sh || exit 1; + +# ds2 means deepspeech2 (acoutic model type) +dir=$PWD/exp/ds2_graph_with_slot +data=$PWD/data +stage=0 +stop_stage=10 + +mkdir -p $dir + +model_dir=$PWD/resource/model +vocab=$model_dir/vocab.txt +cmvn=$data/cmvn.ark +text_with_slot=$data/text_with_slot +resource=$PWD/resource +# download resource +if [ ! -f $cmvn ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/resource.tar.gz + tar xzfv resource.tar.gz + ln -s ./resource/data . +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # make dict + unit_file=$vocab + mkdir -p $dir/local/dict + cp $unit_file $dir/local/dict/units.txt + cp $text_with_slot $dir/train_text + utils/fst/prepare_dict.py --unit_file $unit_file --in_lexicon $data/lexicon.txt \ + --out_lexicon $dir/local/dict/lexicon.txt + # add slot to lexicon, just in case the lm training script filter the slot. + echo " 一" >> $dir/local/dict/lexicon.txt + echo " 一" >> $dir/local/dict/lexicon.txt + echo " 一" >> $dir/local/dict/lexicon.txt + echo " 一" >> $dir/local/dict/lexicon.txt + echo " 一" >> $dir/local/dict/lexicon.txt +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # train lm + lm=$dir/local/lm + mkdir -p $lm + # this script is different with the common lm training script + local/train_lm_with_slot.sh +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # make T & L + local/compile_lexicon_token_fst.sh $dir/local/dict $dir/local/tmp $dir/local/lang + mkdir -p $dir/local/lang_test + # make slot graph + local/mk_slot_graph.sh $resource/graph $dir/local/lang_test + # make TLG + local/mk_tlg_with_slot.sh $dir/local/lm $dir/local/lang $dir/local/lang_test || exit 1; + mv $dir/local/lang_test/TLG.fst $dir/local/lang/ +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # test TLG + model_dir=$PWD/resource/model + cmvn=$data/cmvn.ark + wav_scp=$data/wav.scp + graph=$dir/local/lang + + recognizer_test_main \ + --wav_rspecifier=scp:$wav_scp \ + --cmvn_file=$cmvn \ 
+ --use_fbank=true \ + --model_path=$model_dir/avg_10.jit.pdmodel \ + --param_path=$model_dir/avg_10.jit.pdiparams \ + --model_cache_shapes="5-1-2048,5-1-2048" \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --word_symbol_table=$graph/words.txt \ + --graph_path=$graph/TLG.fst --max_active=7500 \ + --acoustic_scale=12 \ + --result_wspecifier=ark,t:./exp/result_run.txt + + # the data/wav.trans is the label. + utils/compute-wer.py --char=1 --v=1 data/wav.trans exp/result_run.txt > exp/wer_run + tail -n 7 exp/wer_run +fi diff --git a/speechx/examples/custom_asr/utils b/speechx/examples/custom_asr/utils new file mode 120000 index 000000000..973afe674 --- /dev/null +++ b/speechx/examples/custom_asr/utils @@ -0,0 +1 @@ +../../../utils \ No newline at end of file diff --git a/speechx/examples/dev/glog/CMakeLists.txt b/speechx/examples/dev/glog/CMakeLists.txt deleted file mode 100644 index b4b0e6358..000000000 --- a/speechx/examples/dev/glog/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -add_executable(glog_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_test.cc) -target_link_libraries(glog_test glog) - - -add_executable(glog_logtostderr_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_test.cc) -target_link_libraries(glog_logtostderr_test glog) \ No newline at end of file diff --git a/speechx/examples/dev/glog/run.sh b/speechx/examples/dev/glog/run.sh deleted file mode 100755 index d3fcdb643..000000000 --- a/speechx/examples/dev/glog/run.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -set +x -set -e - -. ./path.sh - -# 1. compile -if [ ! -d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# 2. 
run -glog_test - -echo "------" -export FLAGS_logtostderr=1 -glog_test - -echo "------" -glog_logtostderr_test diff --git a/speechx/examples/ds2_ol/CMakeLists.txt b/speechx/examples/ds2_ol/CMakeLists.txt deleted file mode 100644 index 08c194846..000000000 --- a/speechx/examples/ds2_ol/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -add_subdirectory(feat) -add_subdirectory(nnet) -add_subdirectory(decoder) -add_subdirectory(websocket) diff --git a/speechx/examples/ds2_ol/README.md b/speechx/examples/ds2_ol/README.md index f405198d9..492d0e1ac 100644 --- a/speechx/examples/ds2_ol/README.md +++ b/speechx/examples/ds2_ol/README.md @@ -2,13 +2,5 @@ ## Examples -* `websocket` - Streaming ASR with websocket. - -* `aishell` - Streaming Decoding under aishell dataset, for local WER test. - -## More - -> The below is for developing and offline testing. Do not run it only if you know what it is. -* nnet -* feat -* decoder +* `websocket` - Streaming ASR with websocket for deepspeech2_aishell. +* `aishell` - Streaming Decoding under aishell dataset, for local WER test. 
diff --git a/speechx/examples/ds2_ol/aishell/README.md b/speechx/examples/ds2_ol/aishell/README.md index 1ed8a67c2..3e7af9244 100644 --- a/speechx/examples/ds2_ol/aishell/README.md +++ b/speechx/examples/ds2_ol/aishell/README.md @@ -42,3 +42,40 @@ Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95 Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95 Other -> 100.00 % N=3 C=0 S=1 D=2 I=0 ``` + +## fbank +``` +bash run_fbank.sh +``` + +### CTC Prefix Beam Search w/o LM + +``` +Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369 +Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369 +Other -> 100.00 % N=3 C=0 S=3 D=0 I=0 +``` + +### CTC Prefix Beam Search w/ LM + +LM: zh_giga.no_cna_cmn.prune01244.klm + +``` +Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720 +Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720 +English -> 0.00 % N=0 C=0 S=0 D=0 I=0 +``` + +### CTC WFST + +LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip) +``` +Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84 +Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84 +Other -> 100.00 % N=3 C=0 S=1 D=2 I=0 +``` + +## build TLG graph +``` + bash run_build_tlg.sh +``` diff --git a/speechx/examples/ngram/zh/local/aishell_train_lms.sh b/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh similarity index 100% rename from speechx/examples/ngram/zh/local/aishell_train_lms.sh rename to speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh index 520129eaf..6e8039350 100755 --- a/speechx/examples/ds2_ol/aishell/path.sh +++ b/speechx/examples/ds2_ol/aishell/path.sh @@ -1,14 +1,24 @@ # This contains the locations of binarys build required for running the examples. -SPEECHX_ROOT=$PWD/../../.. 
-SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples +MAIN_ROOT=`realpath $PWD/../../../../` +SPEECHX_ROOT=$PWD/../../../ +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin -[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } +[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } export LC_AL=C -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN +# openfst bin & kaldi bin +KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/ +OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src + +# srilm +export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs +export SRILM=${MAIN_ROOT}/tools/srilm + +SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index 650cb1409..82e889ce5 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -69,27 +69,27 @@ export GLOG_logtostderr=1 cmvn=$data/cmvn.ark if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # 3. 
gen linear feat - cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn + cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ - linear-spectrogram-wo-db-norm-ol \ + compute_linear_spectrogram_main \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ --cmvn_file=$cmvn \ - --streaming_chunk=0.36 echo "feature make have finished!!!" fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # recognizer utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ - ctc-prefix-beam-search-decoder-ol \ + ctc_prefix_beam_search_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --nnet_decoder_chunk=8 \ --dict_file=$vocb_dir/vocab.txt \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result @@ -97,16 +97,18 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!" 
echo "please checkout in ${exp}/${wer}" + tail -n 7 $exp/${wer} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # decode with lm utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ - ctc-prefix-beam-search-decoder-ol \ + ctc_prefix_beam_search_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --nnet_decoder_chunk=8 \ --dict_file=$vocb_dir/vocab.txt \ --lm_path=$lm \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm @@ -115,6 +117,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!" echo "please checkout in ${exp}/${wer}.lm" + tail -n 7 $exp/${wer}.lm fi wfst=$data/wfst/ @@ -132,13 +135,14 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # TLG decoder utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ - wfst-decoder-ol \ + tlg_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --graph_path=$wfst/TLG.fst --max_active=7500 \ + --nnet_decoder_chunk=8 \ --acoustic_scale=1.2 \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg @@ -146,18 +150,19 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg echo "wfst-decoder-ol have finished!!!" 
echo "please checkout in ${exp}/${wer}.tlg" + tail -n 7 $exp/${wer}.tlg fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # TLG decoder utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \ - recognizer_test_main \ + recognizer_main \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --cmvn_file=$cmvn \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --streaming_chunk=30 \ --param_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ + --nnet_decoder_chunk=8 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --graph_path=$wfst/TLG.fst --max_active=7500 \ --acoustic_scale=1.2 \ @@ -167,4 +172,5 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer echo "recognizer test have finished!!!" echo "please checkout in ${exp}/${wer}.recognizer" + tail -n 7 $exp/${wer}.recognizer fi diff --git a/speechx/examples/ds2_ol/aishell/run_build_tlg.sh b/speechx/examples/ds2_ol/aishell/run_build_tlg.sh new file mode 100755 index 000000000..2e148657b --- /dev/null +++ b/speechx/examples/ds2_ol/aishell/run_build_tlg.sh @@ -0,0 +1,141 @@ +#!/bin/bash +set -eo pipefail + +. path.sh + +# attention, please replace the vocab is only for this script. +# different acustic model has different vocab +ckpt_dir=data/fbank_model +unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_pice +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ + +stage=-1 +stop_stage=100 +corpus=aishell +lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt +text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + +. utils/parse_options.sh + +data=$PWD/data +mkdir -p $data + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + if [ ! 
-f $data/speech.ngram.zh.tar.gz ];then + pushd $data + wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz + tar xvzf speech.ngram.zh.tar.gz + popd + fi + + if [ ! -f $ckpt_dir/data/mean_std.json ]; then + mkdir -p $ckpt_dir + pushd $ckpt_dir + wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz + tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz + popd + fi +fi + +if [ ! -f $unit ]; then + echo "$0: No such file $unit" + exit 1; +fi + +if ! which ngram-count; then + pushd $MAIN_ROOT/tools + make srilm.done + popd +fi + +mkdir -p data/local/dict +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # Prepare dict + # line: char/spm_pices + cp $unit data/local/dict/units.txt + + if [ ! -f $lexicon ];then + utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon + echo "Generate $lexicon from $text" + fi + + # filter by vocab + # line: word ph0 ... phn -> line: word char0 ... charn + utils/fst/prepare_dict.py \ + --unit_file $unit \ + --in_lexicon ${lexicon} \ + --out_lexicon data/local/dict/lexicon.txt +fi + +lm=data/local/lm +mkdir -p $lm + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # Train lm + cp $text $lm/text + local/aishell_train_lms.sh + echo "build LM done." +fi + +# build TLG +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # build T & L + utils/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + + # build G & TLG + utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + +fi + +aishell_wav_scp=aishell_test.scp +nj=40 +cmvn=$data/cmvn_fbank.ark +wfst=$data/lang_test + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + + if [ ! 
-d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp + fi + + ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj + + cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn +fi + +wer=aishell_wer +label_file=aishell_result +export GLOG_logtostderr=1 + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # TLG decoder + utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \ + recognizer_main \ + --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ + --cmvn_file=$cmvn \ + --model_path=$model_dir/avg_5.jit.pdmodel \ + --streaming_chunk=30 \ + --use_fbank=true \ + --param_path=$model_dir/avg_5.jit.pdiparams \ + --word_symbol_table=$wfst/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --model_cache_shapes="5-1-2048,5-1-2048" \ + --graph_path=$wfst/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg + + cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg + utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg + echo "recognizer test have finished!!!" + echo "please checkout in ${exp}/${wer}.check_tlg" +fi + +exit 0 diff --git a/speechx/examples/ds2_ol/aishell/run_fbank.sh b/speechx/examples/ds2_ol/aishell/run_fbank.sh index 3d4825ace..720728354 100755 --- a/speechx/examples/ds2_ol/aishell/run_fbank.sh +++ b/speechx/examples/ds2_ol/aishell/run_fbank.sh @@ -69,7 +69,7 @@ export GLOG_logtostderr=1 cmvn=$data/cmvn_fbank.ark if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # 3. 
gen linear feat - cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false + cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj @@ -84,34 +84,38 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # recognizer utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ - ctc-prefix-beam-search-decoder-ol \ + ctc_prefix_beam_search_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --model_path=$model_dir/avg_5.jit.pdmodel \ --param_path=$model_dir/avg_5.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_cache_shapes="5-1-2048,5-1-2048" \ + --nnet_decoder_chunk=8 \ --dict_file=$vocb_dir/vocab.txt \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank cat $data/split${nj}/*/result_fbank > $exp/${label_file} utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} + tail -n 7 $exp/${wer} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # decode with lm utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ - ctc-prefix-beam-search-decoder-ol \ + ctc_prefix_beam_search_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --model_path=$model_dir/avg_5.jit.pdmodel \ --param_path=$model_dir/avg_5.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_cache_shapes="5-1-2048,5-1-2048" \ + --nnet_decoder_chunk=8 \ --dict_file=$vocb_dir/vocab.txt \ --lm_path=$lm \ --result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm + tail -n 7 $exp/${wer}.lm fi wfst=$data/wfst_fbank/ @@ -129,13 +133,14 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # TLG decoder 
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ - wfst-decoder-ol \ + tlg_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --model_path=$model_dir/avg_5.jit.pdmodel \ --param_path=$model_dir/avg_5.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_cache_shapes="5-1-2048,5-1-2048" \ + --nnet_decoder_chunk=8 \ --graph_path=$wfst/TLG.fst --max_active=7500 \ --acoustic_scale=1.2 \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg @@ -144,21 +149,21 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg echo "wfst-decoder-ol have finished!!!" echo "please checkout in ${exp}/${wer}.tlg" + tail -n 7 $exp/${wer}.tlg fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \ - recognizer_test_main \ + recognizer_main \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --cmvn_file=$cmvn \ --model_path=$model_dir/avg_5.jit.pdmodel \ - --streaming_chunk=30 \ --use_fbank=true \ - --to_float32=false \ --param_path=$model_dir/avg_5.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_cache_shapes="5-1-2048,5-1-2048" \ + --nnet_decoder_chunk=8 \ --graph_path=$wfst/TLG.fst --max_active=7500 \ --acoustic_scale=1.2 \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer @@ -167,4 +172,5 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer echo "recognizer test have finished!!!" 
echo "please checkout in ${exp}/${wer}.recognizer" + tail -n 7 $exp/${wer}.recognizer fi diff --git a/speechx/examples/ds2_ol/decoder/CMakeLists.txt b/speechx/examples/ds2_ol/decoder/CMakeLists.txt deleted file mode 100644 index 62dd6862e..000000000 --- a/speechx/examples/ds2_ol/decoder/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -set(bin_name ctc-prefix-beam-search-decoder-ol) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - - -set(bin_name wfst-decoder-ol) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) - - -set(bin_name nnet-logprob-decoder-test) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - -add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc) -target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) diff --git a/speechx/examples/ds2_ol/decoder/local/model.sh b/speechx/examples/ds2_ol/decoder/local/model.sh deleted file mode 100644 index 5c609a6cf..000000000 --- a/speechx/examples/ds2_ol/decoder/local/model.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - - diff --git 
a/speechx/examples/ds2_ol/decoder/path.sh b/speechx/examples/ds2_ol/decoder/path.sh deleted file mode 100644 index 8e26e6e7e..000000000 --- a/speechx/examples/ds2_ol/decoder/path.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } - -export LC_AL=C - -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/feat/.gitignore b/speechx/examples/ds2_ol/feat/.gitignore deleted file mode 100644 index 566f2d97b..000000000 --- a/speechx/examples/ds2_ol/feat/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -exp -data diff --git a/speechx/examples/ds2_ol/feat/CMakeLists.txt b/speechx/examples/ds2_ol/feat/CMakeLists.txt deleted file mode 100644 index 632f22e85..000000000 --- a/speechx/examples/ds2_ol/feat/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -set(bin_name linear-spectrogram-wo-db-norm-ol) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog) - -set(bin_name compute_fbank_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog) - -set(bin_name cmvn-json2kaldi) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} 
${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog) diff --git a/speechx/examples/ds2_ol/nnet/path.sh b/speechx/examples/ds2_ol/nnet/path.sh deleted file mode 100644 index 0ee8b4787..000000000 --- a/speechx/examples/ds2_ol/nnet/path.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } - -export LC_AL=C - -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ngram/.gitignore b/speechx/examples/ds2_ol/onnx/.gitignore similarity index 69% rename from speechx/examples/ngram/.gitignore rename to speechx/examples/ds2_ol/onnx/.gitignore index bbd86a25b..f862f73e2 100644 --- a/speechx/examples/ngram/.gitignore +++ b/speechx/examples/ds2_ol/onnx/.gitignore @@ -1,2 +1,3 @@ data +log exp diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md new file mode 100644 index 000000000..566a4597d --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/README.md @@ -0,0 +1,37 @@ +# DeepSpeech2 ONNX model + +1. convert deepspeech2 model to ONNX, using Paddle2ONNX. +2. check paddleinference and onnxruntime output equal. +3. optimize onnx model +4. check paddleinference and optimized onnxruntime output equal. + +Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct. 
+ +The example test with these packages installed: +``` +paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a +paddleaudio 0.2.1 +paddlefsl 1.1.0 +paddlenlp 2.2.6 +paddlepaddle-gpu 2.2.2 +paddlespeech 0.0.0 # develop +paddlespeech-ctcdecoders 0.2.0 +paddlespeech-feat 0.1.0 +onnx 1.11.0 +onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape +onnxoptimizer 0.2.7 +onnxruntime 1.11.0 +``` + +## Using + +``` +bash run.sh +``` + +For more details please see `run.sh`. + +## Outputs +The optimized onnx model is `exp/model.opt.onnx`. + +To show the graph, please using `local/netron.sh`. diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py new file mode 100755 index 000000000..307a764ce --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/infer_check.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import os +import pickle + +import numpy as np +import onnxruntime +import paddle + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--input_file', + type=str, + default="static_ds2online_inputs.pickle", + help="ds2 input pickle file.", ) + parser.add_argument( + '--model_dir', type=str, default=".", help="paddle model dir.") + parser.add_argument( + '--model_prefix', + type=str, + default="avg_1.jit", + help="paddle model prefix.") + parser.add_argument( + '--onnx_model', + type=str, + default='./model.old.onnx', + help="onnx model.") + + return parser.parse_args() + + +if __name__ == '__main__': + FLAGS = parse_args() + + # input and output + with open(FLAGS.input_file, 'rb') as f: + iodict = pickle.load(f) + print(iodict.keys()) + + audio_chunk = iodict['audio_chunk'] + audio_chunk_lens = iodict['audio_chunk_lens'] + chunk_state_h_box = iodict['chunk_state_h_box'] + chunk_state_c_box = iodict['chunk_state_c_bos'] + + # paddle + model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix)) + res_chunk, res_lens, chunk_state_h, chunk_state_c = model( + paddle.to_tensor(audio_chunk), + paddle.to_tensor(audio_chunk_lens), + paddle.to_tensor(chunk_state_h_box), + paddle.to_tensor(chunk_state_c_box), ) + + # onnxruntime + options = onnxruntime.SessionOptions() + options.enable_profiling = True + sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options) + ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run( + ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], { + "audio_chunk": audio_chunk, + "audio_chunk_lens": audio_chunk_lens, + "chunk_state_h_box": chunk_state_h_box, + "chunk_state_c_box": chunk_state_c_box + }) + + print(sess.end_profiling()) + + # assert paddle equal ort + print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) + print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) + print(np.allclose(ort_chunk_state_h, 
chunk_state_h, atol=1e-6)) + print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) diff --git a/speechx/examples/ds2_ol/onnx/local/netron.sh b/speechx/examples/ds2_ol/onnx/local/netron.sh new file mode 100755 index 000000000..6dd9a39c9 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/netron.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# show model + +if [ $# != 1 ];then + echo "usage: $0 model_path" + exit 1 +fi + + +file=$1 + +pip install netron +netron -p 8082 --host $(hostname -i) $file \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh b/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh new file mode 100755 index 000000000..bce22dbc8 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh @@ -0,0 +1,7 @@ + +#!/bin/bash + +# clone onnx repos +git clone https://github.com/onnx/onnx.git +git clone https://github.com/microsoft/onnxruntime.git +git clone https://github.com/PaddlePaddle/Paddle2ONNX.git \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py new file mode 100755 index 000000000..c41e66b72 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py @@ -0,0 +1,2515 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# flake8: noqa +import argparse +import logging + +import numpy as np +import onnx +import sympy +from onnx import helper +from onnx import numpy_helper +from onnx import shape_inference +from packaging import version +assert version.parse(onnx.__version__) >= version.parse("1.8.0") + +logger = logging.getLogger(__name__) + + +def get_attribute(node, attr_name, default_value=None): + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def get_dim_from_proto(dim): + return getattr(dim, dim.WhichOneof('value')) if type( + dim.WhichOneof('value')) == str else None + + +def is_sequence(type_proto): + cls_type = type_proto.WhichOneof('value') + assert cls_type in ['tensor_type', 'sequence_type'] + return cls_type == 'sequence_type' + + +def get_shape_from_type_proto(type_proto): + assert not is_sequence(type_proto) + if type_proto.tensor_type.HasField('shape'): + return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] + else: + return None # note no shape is different from shape without dim (scalar) + + +def get_shape_from_value_info(vi): + cls_type = vi.type.WhichOneof('value') + if cls_type is None: + return None + if is_sequence(vi.type): + if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'): + return get_shape_from_type_proto(vi.type.sequence_type.elem_type) + else: + return None + else: + return get_shape_from_type_proto(vi.type) + + +def make_named_value_info(name): + vi = onnx.ValueInfoProto() + vi.name = name + return vi + + +def get_shape_from_sympy_shape(sympy_shape): + return [ + None if i is None else (int(i) if is_literal(i) else str(i)) + for i in sympy_shape + ] + + +def is_literal(dim): + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr( + dim, 'is_number') and dim.is_number) + + +def handle_negative_axis(axis, rank): + assert axis < rank and axis >= -rank + return axis if axis >= 0 else rank 
+ axis + + +def get_opset(mp, domain=None): + domain = domain or ['', 'onnx', 'ai.onnx'] + if type(domain) != list: + domain = [domain] + for opset in mp.opset_import: + if opset.domain in domain: + return opset.version + + return None + + +def as_scalar(x): + if type(x) == list: + assert len(x) == 1 + return x[0] + elif type(x) == np.ndarray: + return x.item() + else: + return x + + +def as_list(x, keep_none): + if type(x) == list: + return x + elif type(x) == np.ndarray: + return list(x) + elif keep_none and x is None: + return None + else: + return [x] + + +def sympy_reduce_product(x): + if type(x) == list: + value = sympy.Integer(1) + for v in x: + value = value * v + else: + value = x + return value + + +class SymbolicShapeInference: + def __init__(self, + int_max, + auto_merge, + guess_output_rank, + verbose, + prefix=''): + self.dispatcher_ = { + 'Add': + self._infer_symbolic_compute_ops, + 'ArrayFeatureExtractor': + self._infer_ArrayFeatureExtractor, + 'AveragePool': + self._infer_Pool, + 'BatchNormalization': + self._infer_BatchNormalization, + 'Cast': + self._infer_Cast, + 'CategoryMapper': + self._infer_CategoryMapper, + 'Compress': + self._infer_Compress, + 'Concat': + self._infer_Concat, + 'ConcatFromSequence': + self._infer_ConcatFromSequence, + 'Constant': + self._infer_Constant, + 'ConstantOfShape': + self._infer_ConstantOfShape, + 'Conv': + self._infer_Conv, + 'CumSum': + self._pass_on_shape_and_type, + 'Div': + self._infer_symbolic_compute_ops, + 'Einsum': + self._infer_Einsum, + 'Expand': + self._infer_Expand, + 'Equal': + self._infer_symbolic_compute_ops, + 'Floor': + self._infer_symbolic_compute_ops, + 'Gather': + self._infer_Gather, + 'GatherElements': + self._infer_GatherElements, + 'GatherND': + self._infer_GatherND, + 'Gelu': + self._pass_on_shape_and_type, + 'If': + self._infer_If, + 'Loop': + self._infer_Loop, + 'MatMul': + self._infer_MatMul, + 'MatMulInteger16': + self._infer_MatMulInteger, + 'MaxPool': + self._infer_Pool, + 'Max': + 
self._infer_symbolic_compute_ops, + 'Min': + self._infer_symbolic_compute_ops, + 'Mul': + self._infer_symbolic_compute_ops, + 'NonMaxSuppression': + self._infer_NonMaxSuppression, + 'NonZero': + self._infer_NonZero, + 'OneHot': + self._infer_OneHot, + 'Pad': + self._infer_Pad, + 'Range': + self._infer_Range, + 'Reciprocal': + self._pass_on_shape_and_type, + 'ReduceSum': + self._infer_ReduceSum, + 'ReduceProd': + self._infer_ReduceProd, + 'Reshape': + self._infer_Reshape, + 'Resize': + self._infer_Resize, + 'Round': + self._pass_on_shape_and_type, + 'Scan': + self._infer_Scan, + 'ScatterElements': + self._infer_ScatterElements, + 'SequenceAt': + self._infer_SequenceAt, + 'SequenceInsert': + self._infer_SequenceInsert, + 'Shape': + self._infer_Shape, + 'Size': + self._infer_Size, + 'Slice': + self._infer_Slice, + 'SoftmaxCrossEntropyLoss': + self._infer_SoftmaxCrossEntropyLoss, + 'SoftmaxCrossEntropyLossInternal': + self._infer_SoftmaxCrossEntropyLoss, + 'NegativeLogLikelihoodLossInternal': + self._infer_SoftmaxCrossEntropyLoss, + 'Split': + self._infer_Split, + 'SplitToSequence': + self._infer_SplitToSequence, + 'Squeeze': + self._infer_Squeeze, + 'Sub': + self._infer_symbolic_compute_ops, + 'Tile': + self._infer_Tile, + 'TopK': + self._infer_TopK, + 'Transpose': + self._infer_Transpose, + 'Unsqueeze': + self._infer_Unsqueeze, + 'Where': + self._infer_symbolic_compute_ops, + 'ZipMap': + self._infer_ZipMap, + 'Neg': + self._infer_symbolic_compute_ops, + # contrib ops: + 'Attention': + self._infer_Attention, + 'BiasGelu': + self._infer_BiasGelu, + 'EmbedLayerNormalization': + self._infer_EmbedLayerNormalization, + 'FastGelu': + self._infer_FastGelu, + 'Gelu': + self._infer_Gelu, + 'LayerNormalization': + self._infer_LayerNormalization, + 'LongformerAttention': + self._infer_LongformerAttention, + 'PythonOp': + self._infer_PythonOp, + 'SkipLayerNormalization': + self._infer_SkipLayerNormalization + } + self.aten_op_dispatcher_ = { + 'aten::embedding': 
self._infer_Gather, + 'aten::bitwise_or': self._infer_aten_bitwise_or, + 'aten::diagonal': self._infer_aten_diagonal, + 'aten::max_pool2d_with_indices': self._infer_aten_pool2d, + 'aten::multinomial': self._infer_aten_multinomial, + 'aten::unfold': self._infer_aten_unfold, + 'aten::argmax': self._infer_aten_argmax, + 'aten::avg_pool2d': self._infer_aten_pool2d, + 'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d, + 'aten::binary_cross_entropy_with_logits': self._infer_aten_bce, + 'aten::numpy_T': self._infer_Transpose, + } + self.run_ = True + self.suggested_merge_ = {} + self.symbolic_dims_ = {} + self.input_symbols_ = {} + self.auto_merge_ = auto_merge + self.guess_output_rank_ = guess_output_rank + self.verbose_ = verbose + self.int_max_ = int_max + self.subgraph_id_ = 0 + self.prefix_ = prefix + + def _add_suggested_merge(self, symbols, apply=False): + assert all([(type(s) == str and s in self.symbolic_dims_) or + is_literal(s) for s in symbols]) + symbols = set(symbols) + for k, v in self.suggested_merge_.items(): + if k in symbols: + symbols.remove(k) + symbols.add(v) + map_to = None + # if there is literal, map to it first + for s in symbols: + if is_literal(s): + map_to = s + break + # when no literals, map to input symbolic dims, then existing symbolic dims + if map_to is None: + for s in symbols: + if s in self.input_symbols_: + map_to = s + break + if map_to is None: + for s in symbols: + if type(self.symbolic_dims_[s]) == sympy.Symbol: + map_to = s + break + # when nothing to map to, use the shorter one + if map_to is None: + if self.verbose_ > 0: + logger.warning( + 'Potential unsafe merge between symbolic expressions: ({})'. 
+ format(','.join(symbols))) + symbols_list = list(symbols) + lens = [len(s) for s in symbols_list] + map_to = symbols_list[lens.index(min(lens))] + symbols.remove(map_to) + + for s in symbols: + if s == map_to: + continue + if is_literal(map_to) and is_literal(s): + assert int(map_to) == int(s) + self.suggested_merge_[s] = int(map_to) if is_literal( + map_to) else map_to + for k, v in self.suggested_merge_.items(): + if v == s: + self.suggested_merge_[k] = map_to + if apply and self.auto_merge_: + self._apply_suggested_merge() + + def _apply_suggested_merge(self, graph_input_only=False): + if not self.suggested_merge_: + return + for i in list(self.out_mp_.graph.input) + ( + [] if graph_input_only else list(self.out_mp_.graph.value_info)): + for d in i.type.tensor_type.shape.dim: + if d.dim_param in self.suggested_merge_: + v = self.suggested_merge_[d.dim_param] + if is_literal(v): + d.dim_value = int(v) + else: + d.dim_param = v + + def _preprocess(self, in_mp): + self.out_mp_ = onnx.ModelProto() + self.out_mp_.CopyFrom(in_mp) + self.graph_inputs_ = dict( + [(i.name, i) for i in list(self.out_mp_.graph.input)]) + self.initializers_ = dict( + [(i.name, i) for i in self.out_mp_.graph.initializer]) + self.known_vi_ = dict( + [(i.name, i) for i in list(self.out_mp_.graph.input)]) + self.known_vi_.update( + dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, + list(i.dims))) + for i in self.out_mp_.graph.initializer])) + + def _merge_symbols(self, dims): + if not all([type(d) == str for d in dims]): + if self.auto_merge_: + unique_dims = list(set(dims)) + is_int = [is_literal(d) for d in unique_dims] + assert sum( + is_int + ) <= 1 # if there are more than 1 unique ints, something is wrong + if sum(is_int) == 1: + int_dim = is_int.index(1) + if self.verbose_ > 0: + logger.debug('dim {} has been merged with value {}'. 
+ format(unique_dims[:int_dim] + unique_dims[ + int_dim + 1:], unique_dims[int_dim])) + self._check_merged_dims(unique_dims, allow_broadcast=False) + return unique_dims[int_dim] + else: + if self.verbose_ > 0: + logger.debug('dim {} has been mergd with dim {}'.format( + unique_dims[1:], unique_dims[0])) + return dims[0] + else: + return None + if all([d == dims[0] for d in dims]): + return dims[0] + merged = [ + self.suggested_merge_[d] if d in self.suggested_merge_ else d + for d in dims + ] + if all([d == merged[0] for d in merged]): + assert merged[0] in self.symbolic_dims_ + return merged[0] + else: + return None + + # broadcast from right to left, and merge symbolic dims if needed + def _broadcast_shapes(self, shape1, shape2): + new_shape = [] + rank1 = len(shape1) + rank2 = len(shape2) + new_rank = max(rank1, rank2) + for i in range(new_rank): + dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 + dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 + if dim1 == 1 or dim1 == dim2: + new_dim = dim2 + elif dim2 == 1: + new_dim = dim1 + else: + new_dim = self._merge_symbols([dim1, dim2]) + if not new_dim: + # warning about unsupported broadcast when not auto merge + # note that auto merge has the risk of incorrectly merge symbols while one of them being 1 + # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' + if self.auto_merge_: + self._add_suggested_merge([dim1, dim2], apply=True) + else: + logger.warning('unsupported broadcast between ' + str( + dim1) + ' ' + str(dim2)) + new_shape = [new_dim] + new_shape + return new_shape + + def _get_shape(self, node, idx): + name = node.input[idx] + if name in self.known_vi_: + vi = self.known_vi_[name] + return get_shape_from_value_info(vi) + else: + assert name in self.initializers_ + return list(self.initializers_[name].dims) + + def _get_shape_rank(self, node, idx): + return len(self._get_shape(node, idx)) + + def _get_sympy_shape(self, node, idx): + sympy_shape = [] + for d in 
self._get_shape(node, idx): + if type(d) == str: + sympy_shape.append(self.symbolic_dims_[d] if d in + self.symbolic_dims_ else sympy.Symbol( + d, integer=True, nonnegative=True)) + else: + assert None != d + sympy_shape.append(d) + return sympy_shape + + def _get_value(self, node, idx): + name = node.input[idx] + assert name in self.sympy_data_ or name in self.initializers_ + return self.sympy_data_[ + name] if name in self.sympy_data_ else numpy_helper.to_array( + self.initializers_[name]) + + def _try_get_value(self, node, idx): + if idx >= len(node.input): + return None + name = node.input[idx] + if name in self.sympy_data_ or name in self.initializers_: + return self._get_value(node, idx) + return None + + def _update_computed_dims(self, new_sympy_shape): + for i, new_dim in enumerate(new_sympy_shape): + if not is_literal(new_dim) and not type(new_dim) == str: + str_dim = str(new_dim) + if str_dim in self.suggested_merge_: + if is_literal(self.suggested_merge_[str_dim]): + continue # no need to create dim for literals + new_sympy_shape[i] = self.symbolic_dims_[ + self.suggested_merge_[str_dim]] + else: + # add new_dim if it's a computational expression + if not str(new_dim) in self.symbolic_dims_: + self.symbolic_dims_[str(new_dim)] = new_dim + + def _onnx_infer_single_node(self, node): + # skip onnx shape inference for some ops, as they are handled in _infer_* + skip_infer = node.op_type in [ + 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ + # contrib ops + 'Attention', 'BiasGelu', \ + 'EmbedLayerNormalization', \ + 'FastGelu', 'Gelu', 'LayerNormalization', \ + 'LongformerAttention', \ + 'SkipLayerNormalization', \ + 'PythonOp' + ] + + if not skip_infer: + # Only pass initializers that satisfy the following condition: + # (1) Operator need value of some input for shape inference. + # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. + # (2) opset version >= 9. 
In older version, initializer is required in graph input by onnx spec. + # (3) The initializer is not in graph input. The means the node input is "constant" in inference. + initializers = [] + if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: + initializers = [ + self.initializers_[name] for name in node.input + if (name in self.initializers_ and + name not in self.graph_inputs_) + ] + + # run single node inference with self.known_vi_ shapes + tmp_graph = helper.make_graph( + [node], 'tmp', [self.known_vi_[i] for i in node.input if i], + [make_named_value_info(i) for i in node.output], initializers) + + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) + + for i_o in range(len(node.output)): + o = node.output[i_o] + vi = self.out_mp_.graph.value_info.add() + if not skip_infer: + vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) + else: + vi.name = o + self.known_vi_[o] = vi + + def _onnx_infer_subgraph(self, + node, + subgraph, + use_node_input=True, + inc_subgraph_id=True): + if self.verbose_ > 2: + logger.debug( + 'Inferencing subgraph of node {} with output({}...): {}'.format( + node.name, node.output[0], node.op_type)) + # node inputs are not passed directly to the subgraph + # it's up to the node dispatcher to prepare subgraph input + # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape + # besides, inputs in subgraph could shadow implicit inputs + subgraph_inputs = set( + [i.name for i in list(subgraph.initializer) + list(subgraph.input)]) + subgraph_implicit_input = set([ + name for name in self.known_vi_.keys() + if not name in subgraph_inputs + ]) + tmp_graph = helper.make_graph( + list(subgraph.node), 'tmp', + list(subgraph.input) + + [self.known_vi_[i] for i in subgraph_implicit_input], + [make_named_value_info(i.name) for i in subgraph.output]) + tmp_graph.initializer.extend([ + i for i in self.out_mp_.graph.initializer + if i.name in 
subgraph_implicit_input + ]) + tmp_graph.initializer.extend(subgraph.initializer) + self.tmp_mp_.graph.CopyFrom(tmp_graph) + + symbolic_shape_inference = SymbolicShapeInference( + self.int_max_, + self.auto_merge_, + self.guess_output_rank_, + self.verbose_, + prefix=self.prefix_ + '_' + str(self.subgraph_id_)) + if inc_subgraph_id: + self.subgraph_id_ += 1 + + all_shapes_inferred = False + symbolic_shape_inference._preprocess(self.tmp_mp_) + symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() + while symbolic_shape_inference.run_: + all_shapes_inferred = symbolic_shape_inference._infer_impl( + self.sympy_data_.copy()) + symbolic_shape_inference._update_output_from_vi() + if use_node_input: + # if subgraph uses node input, it needs to update to merged dims + subgraph.ClearField('input') + subgraph.input.extend( + symbolic_shape_inference.out_mp_.graph.input[:len(node.input)]) + subgraph.ClearField('output') + subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) + subgraph.ClearField('value_info') + subgraph.value_info.extend( + symbolic_shape_inference.out_mp_.graph.value_info) + subgraph.ClearField('node') + subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) + # for new symbolic dims from subgraph output, add to main graph symbolic dims + subgraph_shapes = [ + get_shape_from_value_info(o) + for o in symbolic_shape_inference.out_mp_.graph.output + ] + subgraph_new_symbolic_dims = set([ + d for s in subgraph_shapes if s for d in s + if type(d) == str and not d in self.symbolic_dims_ + ]) + new_dims = {} + for d in subgraph_new_symbolic_dims: + assert d in symbolic_shape_inference.symbolic_dims_ + new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] + self.symbolic_dims_.update(new_dims) + return symbolic_shape_inference + + def _get_int_values(self, node, broadcast=False): + values = [self._try_get_value(node, i) for i in range(len(node.input))] + if all([v is not None for v in values]): + # some shape compute 
is in floating point, cast to int for sympy + for i, v in enumerate(values): + if type(v) != np.ndarray: + continue + if len(v.shape) > 1: + new_v = None # ignore value for rank > 1 + elif len(v.shape) == 0: + new_v = int(v.item()) + else: + assert len(v.shape) == 1 + new_v = [int(vv) for vv in v] + values[i] = new_v + values_len = [len(v) if type(v) == list else 0 for v in values] + max_len = max(values_len) + if max_len >= 1 and broadcast: + # broadcast + for i, v in enumerate(values): + if v is None: + continue # don't broadcast if value is unknown + if type(v) == list: + if len(v) < max_len: + values[i] = v * max_len + else: + assert len(v) == max_len + else: + values[i] = [v] * max_len + return values + + def _compute_on_sympy_data(self, node, op_func): + assert len(node.output) == 1 + values = self._get_int_values(node, broadcast=True) + if all([v is not None for v in values]): + is_list = [type(v) == list for v in values] + as_list = any(is_list) + if as_list: + self.sympy_data_[node.output[ + 0]] = [op_func(vs) for vs in zip(*values)] + else: + self.sympy_data_[node.output[0]] = op_func(values) + + def _pass_on_sympy_data(self, node): + assert len( + node. 
+ input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze'] + self._compute_on_sympy_data(node, lambda x: x[0]) + + def _pass_on_shape_and_type(self, node): + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, + self._get_shape(node, 0))) + + def _new_symbolic_dim(self, prefix, dim): + new_dim = '{}_d{}'.format(prefix, dim) + if new_dim in self.suggested_merge_: + v = self.suggested_merge_[new_dim] + new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v + else: + new_symbolic_dim = sympy.Symbol( + new_dim, integer=True, nonnegative=True) + self.symbolic_dims_[new_dim] = new_symbolic_dim + return new_symbolic_dim + + def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): + return self._new_symbolic_dim('{}{}_{}_o{}_'.format( + node.op_type, self.prefix_, + list(self.out_mp_.graph.node).index(node), out_idx), dim) + + def _new_symbolic_shape(self, rank, node, out_idx=0): + return [ + self._new_symbolic_dim_from_output(node, out_idx, i) + for i in range(rank) + ] + + def _compute_conv_pool_shape(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + if len(node.input) > 1: + W_shape = self._get_sympy_shape(node, 1) + rank = len(W_shape) - 2 # number of spatial axes + kernel_shape = W_shape[-rank:] + sympy_shape[1] = W_shape[0] + else: + W_shape = None + kernel_shape = get_attribute(node, 'kernel_shape') + rank = len(kernel_shape) + + assert len(sympy_shape) == rank + 2 + + # only need to symbolic shape inference if input has symbolic dims in spatial axes + is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]] + + if not any(is_symbolic_dims): + shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) + if len(shape) > 0: + assert len(sympy_shape) == len(shape) + sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] + return sympy_shape + + dilations = get_attribute(node, 'dilations', [1] * 
rank) + strides = get_attribute(node, 'strides', [1] * rank) + effective_kernel_shape = [(k - 1) * d + 1 + for k, d in zip(kernel_shape, dilations)] + pads = get_attribute(node, 'pads') + if pads is None: + pads = [0] * (2 * rank) + auto_pad = get_attribute(node, 'auto_pad', + b'NOTSET').decode('utf-8') + if auto_pad != 'VALID' and auto_pad != 'NOTSET': + try: + residual = [ + sympy.Mod(d, s) + for d, s in zip(sympy_shape[-rank:], strides) + ] + total_pads = [ + max(0, (k - s) if r == 0 else (k - r)) for k, s, r in + zip(effective_kernel_shape, strides, residual) + ] + except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational + total_pads = [ + max(0, (k - s)) + for k, s in zip(effective_kernel_shape, strides) + ] # assuming no residual if sympy throws error + elif auto_pad == 'VALID': + total_pads = [] + else: + total_pads = [0] * rank + else: + assert len(pads) == 2 * rank + total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] + + ceil_mode = get_attribute(node, 'ceil_mode', 0) + for i in range(rank): + effective_input_size = sympy_shape[-rank + i] + if len(total_pads) > 0: + effective_input_size = effective_input_size + total_pads[i] + if ceil_mode: + strided_kernel_positions = sympy.ceiling( + (effective_input_size - effective_kernel_shape[i]) / + strides[i]) + else: + strided_kernel_positions = ( + effective_input_size - effective_kernel_shape[i] + ) // strides[i] + sympy_shape[-rank + i] = strided_kernel_positions + 1 + return sympy_shape + + def _check_merged_dims(self, dims, allow_broadcast=True): + if allow_broadcast: + dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] + if not all([d == dims[0] for d in dims]): + self._add_suggested_merge(dims, apply=True) + + def _compute_matmul_shape(self, node, output_dtype=None): + lhs_shape = self._get_shape(node, 0) + rhs_shape = self._get_shape(node, 1) + lhs_rank = len(lhs_shape) + rhs_rank = len(rhs_shape) + lhs_reduce_dim = 0 + rhs_reduce_dim = 0 + 
assert lhs_rank > 0 and rhs_rank > 0 + if lhs_rank == 1 and rhs_rank == 1: + new_shape = [] + elif lhs_rank == 1: + rhs_reduce_dim = -2 + new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] + elif rhs_rank == 1: + lhs_reduce_dim = -1 + new_shape = lhs_shape[:lhs_reduce_dim] + else: + lhs_reduce_dim = -1 + rhs_reduce_dim = -2 + new_shape = self._broadcast_shapes( + lhs_shape[:-2], + rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], + allow_broadcast=False) + if output_dtype is None: + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[ + 0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, + new_shape)) + + def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): + ''' + update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches + ''' + dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence( + dst_type) else dst_type.tensor_type + src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence( + src_type) else src_type.tensor_type + if dst_tensor_type.elem_type != src_tensor_type.elem_type: + node_id = node.name if node.name else node.op_type + raise ValueError( + f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " + f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " + f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" + ) + if dst_tensor_type.HasField('shape'): + for di, ds in enumerate( + zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): + if ds[0] != ds[1]: + # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type + # for sequence_type, clear the dimension + new_dim = onnx.TensorShapeProto.Dimension() + if not 
is_sequence(dst_type): + new_dim.dim_param = str( + self._new_symbolic_dim_from_output(node, out_idx, + di)) + dst_tensor_type.shape.dim[di].CopyFrom(new_dim) + else: + dst_tensor_type.CopyFrom(src_tensor_type) + + def _infer_ArrayFeatureExtractor(self, node): + data_shape = self._get_shape(node, 0) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, data_shape[:-1] + + indices_shape)) + + def _infer_symbolic_compute_ops(self, node): + funcs = { + 'Add': + lambda l: l[0] + l[1], + 'Div': + lambda l: l[0] // l[1], # integer div in sympy + 'Equal': + lambda l: l[0] == l[1], + 'Floor': + lambda l: sympy.floor(l[0]), + 'Max': + lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), + 'Min': + lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), + 'Mul': + lambda l: l[0] * l[1], + 'Sub': + lambda l: l[0] - l[1], + 'Where': + lambda l: l[1] if l[0] else l[2], + 'Neg': + lambda l: -l[0] + } + assert node.op_type in funcs + self._compute_on_sympy_data(node, funcs[node.op_type]) + + def _infer_Cast(self, node): + self._pass_on_sympy_data(node) + + def _infer_CategoryMapper(self, node): + input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + if input_type == onnx.TensorProto.STRING: + output_type = onnx.TensorProto.INT64 + else: + output_type = onnx.TensorProto.STRING + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_type, + self._get_shape(node, 0))) + + def _infer_Compress(self, node): + input_shape = self._get_shape(node, 0) + # create a new symbolic dimension for Compress output + compress_len = str(self._new_symbolic_dim_from_output(node)) + 
axis = get_attribute(node, 'axis') + if axis == None: + # when axis is not specified, input is flattened before compress so output is 1D + output_shape = [compress_len] + else: + output_shape = input_shape + output_shape[handle_negative_axis(axis, len( + input_shape))] = compress_len + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, output_shape)) + + def _infer_Concat(self, node): + if any([ + i in self.sympy_data_ or i in self.initializers_ + for i in node.input + ]): + values = self._get_int_values(node) + print("=======", values, node.name, get_attribute(node, 'axis')) + if all([v is not None for v in values]): + axis = get_attribute(node, 'axis') + if axis < 0: + axis = axis + len(values[0]) + assert 0 == axis + self.sympy_data_[node.output[0]] = [] + for i in range(len(node.input)): + value = values[i] + if type(value) == list: + self.sympy_data_[node.output[0]].extend(value) + else: + self.sympy_data_[node.output[0]].append(value) + + sympy_shape = self._get_sympy_shape(node, 0) + axis = handle_negative_axis( + get_attribute(node, 'axis'), len(sympy_shape)) + for i_idx in range(1, len(node.input)): + input_shape = self._get_sympy_shape(node, i_idx) + if input_shape: + sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] + self._update_computed_dims(sympy_shape) + # merge symbolic dims for non-concat axes + for d in range(len(sympy_shape)): + if d == axis: + continue + dims = [ + self._get_shape(node, i_idx)[d] + for i_idx in range(len(node.input)) + if self._get_shape(node, i_idx) + ] + if all([d == dims[0] for d in dims]): + continue + merged = self._merge_symbols(dims) + if type(merged) == str: + sympy_shape[d] = self.symbolic_dims_[merged] if merged else None + else: + sympy_shape[d] = merged + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], 
self.known_vi_[node.input[0]].type.tensor_type. + elem_type, get_shape_from_sympy_shape(sympy_shape))) + + def _infer_ConcatFromSequence(self, node): + seq_shape = self._get_shape(node, 0) + new_axis = 1 if get_attribute(node, 'new_axis') else 0 + axis = handle_negative_axis( + get_attribute(node, 'axis'), len(seq_shape) + new_axis) + concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) + new_shape = seq_shape + if new_axis: + new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] + else: + new_shape[axis] = concat_dim + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[0]] + .type.sequence_type.elem_type.tensor_type.elem_type, new_shape)) + + def _infer_Constant(self, node): + t = get_attribute(node, 'value') + self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) + + def _infer_ConstantOfShape(self, node): + sympy_shape = self._get_int_values(node)[0] + vi = self.known_vi_[node.output[0]] + if sympy_shape is not None: + if type(sympy_shape) != list: + sympy_shape = [sympy_shape] + self._update_computed_dims(sympy_shape) + # update sympy data if output type is int, and shape is known + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( + [is_literal(x) for x in sympy_shape]): + self.sympy_data_[node.output[0]] = np.ones( + [int(x) for x in sympy_shape], + dtype=np.int64) * numpy_helper.to_array( + get_attribute(node, 'value', 0)) + else: + # create new dynamic shape + # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length + sympy_shape = self._new_symbolic_shape( + self._get_shape(node, 0)[0], node) + + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape))) + + def _infer_Conv(self, node): + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + vi = 
self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape))) + + def _infer_Einsum(self, node): + # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 + equation = get_attribute(node, 'equation') + equation = equation.replace(b' ', b'') + mid_index = equation.find(b'->') + left_equation = equation[:mid_index] if mid_index != -1 else equation + + num_operands = 0 + num_ellipsis = 0 + num_ellipsis_indices = 0 + + letter_to_dim = {} + + terms = left_equation.split(b',') + for term in terms: + ellipsis_index = term.find(b'...') + shape = self._get_shape(node, num_operands) + rank = len(shape) + if ellipsis_index != -1: + if num_ellipsis == 0: + num_ellipsis_indices = rank - len(term) + 3 + num_ellipsis = num_ellipsis + 1 + for i in range(1, rank + 1): + letter = term[-i] + if letter != 46: # letter != b'.' + dim = shape[-i] + if letter not in letter_to_dim.keys(): + letter_to_dim[letter] = dim + elif type(dim) != sympy.Symbol: + letter_to_dim[letter] = dim + num_operands = num_operands + 1 + + new_sympy_shape = [] + from collections import OrderedDict + num_letter_occurrences = OrderedDict() + if mid_index != -1: + right_equation = equation[mid_index + 2:] + right_ellipsis_index = right_equation.find(b'...') + if right_ellipsis_index != -1: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in right_equation: + if c != 46: # c != b'.' 
+ new_sympy_shape.append(letter_to_dim[c]) + else: + for i in range(num_ellipsis_indices): + new_sympy_shape.append(shape[i]) + for c in left_equation: + if c != 44 and c != 46: # c != b',' and c != b'.': + if c in num_letter_occurrences: + num_letter_occurrences[c] = num_letter_occurrences[ + c] + 1 + else: + num_letter_occurrences[c] = 1 + for key, value in num_letter_occurrences.items(): + if value == 1: + new_sympy_shape.append(letter_to_dim[key]) + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, + new_sympy_shape)) + + def _infer_Expand(self, node): + expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) + if expand_to_shape is not None: + # new_shape's dim can come from shape value + self._update_computed_dims(expand_to_shape) + shape = self._get_shape(node, 0) + new_shape = self._broadcast_shapes( + shape, get_shape_from_sympy_shape(expand_to_shape)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, new_shape)) + + def _infer_Gather(self, node): + data_shape = self._get_shape(node, 0) + axis = handle_negative_axis( + get_attribute(node, 'axis', 0), len(data_shape)) + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, data_shape[:axis] + + indices_shape + data_shape[axis + + 1:])) + # for 1D input, do some sympy compute + if node.input[0] in self.sympy_data_ and len( + data_shape) == 1 and 0 == get_attribute(node, 'axis', 0): + idx = self._try_get_value(node, 1) + if idx is not None: + data = self.sympy_data_[node.input[0]] + if type(data) == list: + if type(idx) == np.ndarray and len(idx.shape) == 1: + self.sympy_data_[node.output[ + 0]] = 
[data[int(i)] for i in idx] + else: + self.sympy_data_[node.output[0]] = data[int(idx)] + else: + assert idx == 0 or idx == -1 + self.sympy_data_[node.output[0]] = data + + def _infer_GatherElements(self, node): + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, indices_shape)) + + def _infer_GatherND(self, node): + data_shape = self._get_shape(node, 0) + data_rank = len(data_shape) + indices_shape = self._get_shape(node, 1) + indices_rank = len(indices_shape) + last_index_dimension = indices_shape[-1] + assert is_literal( + last_index_dimension) and last_index_dimension <= data_rank + new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, new_shape)) + + def _infer_If(self, node): + # special case for constant condition, in case there are mismatching shape from the non-executed branch + subgraphs = [ + get_attribute(node, 'then_branch'), get_attribute(node, + 'else_branch') + ] + cond = self._try_get_value(node, 0) + if cond is not None: + if as_scalar(cond) > 0: + subgraphs[1].CopyFrom(subgraphs[0]) + else: + subgraphs[0].CopyFrom(subgraphs[1]) + + for i_sub, subgraph in enumerate(subgraphs): + subgraph_infer = self._onnx_infer_subgraph( + node, subgraph, use_node_input=False) + for i_out in range(len(node.output)): + vi = self.known_vi_[node.output[i_out]] + if i_sub == 0: + vi.CopyFrom(subgraph.output[i_out]) + vi.name = node.output[i_out] + else: + self._fuse_tensor_type(node, i_out, vi.type, + subgraph.output[i_out].type) + + # pass on sympy data from subgraph, if cond is constant + if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else + 1): + if subgraph.output[ + i_out].name in subgraph_infer.sympy_data_: + 
self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ + subgraph.output[i_out].name] + + def _infer_Loop(self, node): + subgraph = get_attribute(node, 'body') + assert len(subgraph.input) == len(node.input) + num_loop_carried = len( + node.input) - 2 # minus the length and initial loop condition + # when sequence_type is used as loop carried input + # needs to run subgraph infer twice if the tensor shape in sequence contains None + for i, si in enumerate(subgraph.input): + si_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + si.name = si_name + + self._onnx_infer_subgraph(node, subgraph) + + # check subgraph input/output for shape changes in loop carried variables + # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) + # for sequence_type, propagate from output to input + need_second_infer = False + for i_out in range(1, num_loop_carried + 1): + so = subgraph.output[i_out] + so_shape = get_shape_from_value_info(so) + if is_sequence(so.type): + if so_shape and None in so_shape: + # copy shape from output to input + # note that loop input is [loop_len, cond, input_0, input_1, ...] + # while loop output is [cond, output_0, output_1, ...] + subgraph.input[i_out + + 1].type.sequence_type.elem_type.CopyFrom( + so.type.sequence_type.elem_type) + need_second_infer = True + else: + si = subgraph.input[i_out + 1] + si_shape = get_shape_from_value_info(si) + for di, dims in enumerate(zip(si_shape, so_shape)): + if dims[0] != dims[1]: + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str( + self._new_symbolic_dim_from_output(node, i_out, di)) + si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + need_second_infer = True + + if need_second_infer: + if self.verbose_ > 2: + logger.debug( + "Rerun Loop: {}({}...), because of sequence in loop carried variables". 
+ format(node.name, node.output[0])) + self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) + + # create a new symbolic dimension for iteration dependent dimension + loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) + for i in range(len(node.output)): + vi = self.known_vi_[node.output[i]] + vi.CopyFrom(subgraph.output[ + i + + 1]) # first subgraph output is condition, not in node output + if i >= num_loop_carried: + assert not is_sequence( + vi.type) # TODO: handle loop accumulation in sequence_type + subgraph_vi_dim = subgraph.output[i + + 1].type.tensor_type.shape.dim + vi.type.tensor_type.shape.ClearField('dim') + vi_dim = vi.type.tensor_type.shape.dim + vi_dim.add().dim_param = loop_iter_dim + vi_dim.extend(list(subgraph_vi_dim)) + vi.name = node.output[i] + + def _infer_MatMul(self, node): + self._compute_matmul_shape(node) + + def _infer_MatMulInteger(self, node): + self._compute_matmul_shape(node, onnx.TensorProto.INT32) + + def _infer_NonMaxSuppression(self, node): + selected = str(self._new_symbolic_dim_from_output(node)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], onnx.TensorProto.INT64, [selected, 3])) + + def _infer_NonZero(self, node): + input_rank = self._get_shape_rank(node, 0) + # create a new symbolic dimension for NonZero output + nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) + + def _infer_OneHot(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + depth = self._try_get_value(node, 1) + axis = get_attribute(node, 'axis', -1) + axis = handle_negative_axis(axis, len(sympy_shape) + 1) + new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [ + self._new_symbolic_dim_from_output(node) + if not is_literal(depth) else depth + ] + sympy_shape[axis:]) + vi = 
self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[2]].type.tensor_type.elem_type, new_shape)) + + def _infer_Pad(self, node): + if get_opset(self.out_mp_) <= 10: + pads = get_attribute(node, 'pads') + else: + pads = self._try_get_value(node, 1) + + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + + if pads is not None: + assert len(pads) == 2 * rank + new_sympy_shape = [ + d + pad_up + pad_down for d, pad_up, pad_down in + zip(sympy_shape, pads[:rank], pads[rank:]) + ] + self._update_computed_dims(new_sympy_shape) + else: + # dynamic pads, create new symbolic dimensions + new_sympy_shape = self._new_symbolic_shape(rank, node) + output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))) + + def _infer_Pool(self, node): + sympy_shape = self._compute_conv_pool_shape(node) + self._update_computed_dims(sympy_shape) + for o in node.output: + if not o: + continue + vi = self.known_vi_[o] + vi.CopyFrom( + helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape( + sympy_shape))) + + def _infer_aten_bitwise_or(self, node): + shape0 = self._get_shape(node, 0) + shape1 = self._get_shape(node, 1) + new_shape = self._broadcast_shapes(shape0, shape1) + t0 = self.known_vi_[node.input[0]] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], t0.type.tensor_type.elem_type, new_shape)) + + def _infer_aten_diagonal(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + offset = self._try_get_value(node, 1) + dim1 = self._try_get_value(node, 2) + dim2 = self._try_get_value(node, 3) + + assert offset is not None and dim1 is not None and dim2 is not None + dim1 = handle_negative_axis(dim1, rank) + dim2 = 
handle_negative_axis(dim2, rank) + + new_shape = [] + for dim, val in enumerate(sympy_shape): + if dim not in [dim1, dim2]: + new_shape.append(val) + + shape1 = sympy_shape[dim1] + shape2 = sympy_shape[dim2] + if offset >= 0: + diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) + else: + diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) + new_shape.append(diag_shape) + + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + new_shape))) + + def _infer_aten_multinomial(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + rank = len(sympy_shape) + assert rank in [1, 2] + num_samples = self._try_get_value(node, 1) + di = rank - 1 + last_dim = num_samples if num_samples else str( + self._new_symbolic_dim_from_output(node, 0, di)) + output_shape = sympy_shape[:-1] + [last_dim] + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], onnx.TensorProto.INT64, + get_shape_from_sympy_shape(output_shape))) + + def _infer_aten_pool2d(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + assert len(sympy_shape) == 4 + sympy_shape[-2:] = [ + self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3] + ] + self._update_computed_dims(sympy_shape) + for i, o in enumerate(node.output): + if not o: + continue + vi = self.known_vi_[o] + elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[ + node.input[0]].type.tensor_type.elem_type + vi.CopyFrom( + helper.make_tensor_value_info( + o, elem_type, get_shape_from_sympy_shape(sympy_shape))) + + def _infer_aten_unfold(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + dimension = self._try_get_value(node, 1) + size = self._try_get_value(node, 2) + step = self._try_get_value(node, 3) + if dimension is not None and size is not None and step is not None: + assert dimension < 
len(sympy_shape) + sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 + sympy_shape.append(size) + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank + 1, node) + self._update_computed_dims(sympy_shape) + if node.output[0]: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + sympy_shape))) + + def _infer_aten_argmax(self, node): + new_shape = None + if node.input[1] == '': + # The argmax of the flattened input is returned. + new_shape = [] + else: + dim = self._try_get_value(node, 1) + keepdim = self._try_get_value(node, 2) + if keepdim is not None: + sympy_shape = self._get_sympy_shape(node, 0) + if dim is not None: + dim = handle_negative_axis(dim, len(sympy_shape)) + if keepdim: + sympy_shape[dim] = 1 + else: + del sympy_shape[dim] + else: + rank = len(sympy_shape) + sympy_shape = self._new_symbolic_shape(rank if keepdim else + rank - 1, node) + self._update_computed_dims(sympy_shape) + new_shape = get_shape_from_sympy_shape(sympy_shape) + if node.output[0] and new_shape is not None: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], onnx.TensorProto.INT64, new_shape)) + + def _infer_aten_bce(self, node): + reduction = self._try_get_value(node, 4) + if reduction is None: + reduction = 1 + elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + if reduction == 0: + vi.type.tensor_type.elem_type = elem_type + vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + else: + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, elem_type, + self._get_shape(node, 0))) + + def _infer_BatchNormalization(self, node): + self._propagate_shape_and_type(node) + + # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop + for i in [1, 2, 3, 4]: + if i < 
len(node.output) and node.output[i] != "": + # all of these parameters have the same shape as the 1st input + self._propagate_shape_and_type( + node, input_index=1, output_index=i) + + def _infer_Range(self, node): + vi = self.known_vi_[node.output[0]] + input_data = self._get_int_values(node) + if all([i is not None for i in input_data]): + start = as_scalar(input_data[0]) + limit = as_scalar(input_data[1]) + delta = as_scalar(input_data[2]) + new_sympy_shape = [ + sympy.Max(sympy.ceiling((limit - start) / delta), 0) + ] + else: + new_sympy_shape = [self._new_symbolic_dim_from_output(node)] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[0]].type.tensor_type. + elem_type, get_shape_from_sympy_shape(new_sympy_shape))) + + def _infer_ReduceSum(self, node): + keep_dims = get_attribute(node, 'keepdims', 1) + if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: + # ReduceSum changes axes to input[1] in opset 13 + axes = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if axes is None: + assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[ + 0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + self._new_symbolic_shape( + self._get_shape_rank(node, 0), node)))) + else: + shape = self._get_shape(node, 0) + output_shape = [] + axes = [handle_negative_axis(a, len(shape)) for a in axes] + for i, d in enumerate(shape): + if i in axes: + if keep_dims: + output_shape.append(1) + else: + output_shape.append(d) + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 0], self.known_vi_[node.input[ + 0]].type.tensor_type.elem_type, output_shape)) + + def _infer_ReduceProd(self, node): + axes = get_attribute(node, 'axes') + keep_dims = get_attribute(node, 'keepdims', 1) + if keep_dims == 0 and axes == [0]: + data = 
self._get_int_values(node)[0] + if data is not None: + self.sympy_data_[node.output[0]] = sympy_reduce_product(data) + + def _infer_Reshape(self, node): + shape_value = self._try_get_value(node, 1) + vi = self.known_vi_[node.output[0]] + if shape_value is None: + shape_shape = self._get_shape(node, 1) + assert len(shape_shape) == 1 + shape_rank = shape_shape[0] + assert is_literal(shape_rank) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape( + self._new_symbolic_shape(shape_rank, node)))) + else: + input_sympy_shape = self._get_sympy_shape(node, 0) + total = int(1) + for d in input_sympy_shape: + total = total * d + new_sympy_shape = [] + deferred_dim_idx = -1 + non_deferred_size = int(1) + for i, d in enumerate(shape_value): + if type(d) == sympy.Symbol: + new_sympy_shape.append(d) + elif d == 0: + new_sympy_shape.append(input_sympy_shape[i]) + non_deferred_size = non_deferred_size * input_sympy_shape[i] + else: + new_sympy_shape.append(d) + if d == -1: + deferred_dim_idx = i + elif d != 0: + non_deferred_size = non_deferred_size * d + + assert new_sympy_shape.count(-1) < 2 + if -1 in new_sympy_shape: + new_dim = total // non_deferred_size + new_sympy_shape[deferred_dim_idx] = new_dim + + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape))) + + self._pass_on_sympy_data(node) + + def _infer_Resize(self, node): + vi = self.known_vi_[node.output[0]] + input_sympy_shape = self._get_sympy_shape(node, 0) + if get_opset(self.out_mp_) <= 10: + scales = self._try_get_value(node, 1) + if scales is not None: + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * s)) + for d, s in zip(input_sympy_shape, scales) + ] + self._update_computed_dims(new_sympy_shape) + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], self.known_vi_[node.input[ + 
0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape))) + else: + roi = self._try_get_value(node, 1) + scales = self._try_get_value(node, 2) + sizes = self._try_get_value(node, 3) + if sizes is not None: + new_sympy_shape = [ + sympy.simplify(sympy.floor(s)) for s in sizes + ] + self._update_computed_dims(new_sympy_shape) + elif scales is not None: + rank = len(scales) + if get_attribute(node, 'coordinate_transformation_mode' + ) == 'tf_crop_and_resize': + assert len(roi) == 2 * rank + roi_start = list(roi)[:rank] + roi_end = list(roi)[rank:] + else: + roi_start = [0] * rank + roi_end = [1] * rank + scales = list(scales) + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * (end - start) * scale)) + for d, start, end, scale in + zip(input_sympy_shape, roi_start, roi_end, scales) + ] + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape( + self._get_shape_rank(node, 0), node) + + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + new_sympy_shape))) + + def _infer_Scan(self, node): + subgraph = get_attribute(node, 'body') + num_scan_inputs = get_attribute(node, 'num_scan_inputs') + scan_input_axes = get_attribute(node, 'scan_input_axes', + [0] * num_scan_inputs) + num_scan_states = len(node.input) - num_scan_inputs + scan_input_axes = [ + handle_negative_axis( + ax, self._get_shape_rank(node, i + num_scan_states)) + for i, ax in enumerate(scan_input_axes) + ] + # We may have cases where the subgraph has optionial inputs that appear in both subgraph's input and initializer, + # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. 
+ assert len(subgraph.input) >= len(node.input) + subgraph_inputs = subgraph.input[:len(node.input)] + for i, si in enumerate(subgraph_inputs): + subgraph_name = si.name + si.CopyFrom(self.known_vi_[node.input[i]]) + if i >= num_scan_states: + scan_input_dim = si.type.tensor_type.shape.dim + scan_input_dim.remove( + scan_input_dim[scan_input_axes[i - num_scan_states]]) + si.name = subgraph_name + self._onnx_infer_subgraph(node, subgraph) + num_scan_outputs = len(node.output) - num_scan_states + scan_output_axes = get_attribute(node, 'scan_output_axes', + [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto( + self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] + for i, o in enumerate(node.output): + vi = self.known_vi_[o] + if i >= num_scan_states: + shape = get_shape_from_type_proto(subgraph.output[i].type) + new_dim = handle_negative_axis( + scan_output_axes[i - num_scan_states], len(shape) + 1) + shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] + vi.CopyFrom( + helper.make_tensor_value_info(o, subgraph.output[ + i].type.tensor_type.elem_type, shape)) + else: + vi.CopyFrom(subgraph.output[i]) + vi.name = o + + def _infer_ScatterElements(self, node): + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, data_shape)) + + def _infer_SequenceAt(self, node): + # need to create new symbolic dimension if sequence shape has None: + seq_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + if seq_shape is not None: + for di, d in enumerate(seq_shape): + if d is not None: + continue + new_dim = onnx.TensorShapeProto.Dimension() + new_dim.dim_param = str( + self._new_symbolic_dim_from_output(node, 0, di)) + vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) + + def _infer_SequenceInsert(self, node): + # workaround bug in onnx's shape inference + vi_seq = 
self.known_vi_[node.input[0]] + vi_tensor = self.known_vi_[node.input[1]] + vi_out_seq = self.known_vi_[node.output[0]] + vi_out_seq.CopyFrom(vi_seq) + vi_out_seq.name = node.output[0] + self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) + + def _infer_Shape(self, node): + self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) + + def _infer_Size(self, node): + sympy_shape = self._get_sympy_shape(node, 0) + self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) + self.known_vi_[node.output[0]].CopyFrom( + helper.make_tensor_value_info(node.output[0], + onnx.TensorProto.INT64, [])) + + def _infer_Slice(self, node): + def less_equal(x, y): + try: + return bool(x <= y) + except TypeError: + pass + try: + return bool(y >= x) + except TypeError: + pass + try: + return bool(-x >= -y) + except TypeError: + pass + try: + return bool(-y <= -x) + except TypeError: + # the last attempt; this may raise TypeError + return bool(y - x >= 0) + + def handle_negative_index(index, bound): + """ normalizes a negative index to be in [0, bound) """ + try: + if not less_equal(0, index): + if is_literal(index) and index <= -self.int_max_: + # this case is handled separately + return index + return bound + index + except TypeError: + logger.warning("Cannot determine if {} < 0".format(index)) + return index + + if get_opset(self.out_mp_) <= 9: + axes = get_attribute(node, 'axes') + starts = get_attribute(node, 'starts') + ends = get_attribute(node, 'ends') + if not axes: + axes = list(range(len(starts))) + steps = [1] * len(axes) + else: + starts = as_list(self._try_get_value(node, 1), keep_none=True) + ends = as_list(self._try_get_value(node, 2), keep_none=True) + axes = self._try_get_value(node, 3) + steps = self._try_get_value(node, 4) + if axes is None and not (starts is None and ends is None): + axes = list( + range(0, len(starts if starts is not None else ends))) + if steps is None and not (starts is None and ends is None): + steps = [1] * 
len(starts if starts is not None else ends) + axes = as_list(axes, keep_none=True) + steps = as_list(steps, keep_none=True) + + new_sympy_shape = self._get_sympy_shape(node, 0) + if starts is None or ends is None: + if axes is None: + for i in range(len(new_sympy_shape)): + new_sympy_shape[i] = self._new_symbolic_dim_from_output( + node, 0, i) + else: + new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) + for i in axes: + new_sympy_shape[i] = self._new_symbolic_dim_from_output( + node, 0, i) + else: + for i, s, e, t in zip(axes, starts, ends, steps): + e = handle_negative_index(e, new_sympy_shape[i]) + if is_literal(e): + if e >= self.int_max_: + e = new_sympy_shape[i] + elif e <= -self.int_max_: + e = 0 if s > 0 else -1 + elif is_literal(new_sympy_shape[i]): + if e < 0: + e = max(0, e + new_sympy_shape[i]) + e = min(e, new_sympy_shape[i]) + else: + if e > 0: + e = sympy.Min( + e, new_sympy_shape[i] + ) if e > 1 else e #special case for slicing first to make computation easier + else: + if is_literal(new_sympy_shape[i]): + e = sympy.Min(e, new_sympy_shape[i]) + else: + try: + if not less_equal(e, new_sympy_shape[i]): + e = new_sympy_shape[i] + except Exception: + logger.warning( + 'Unable to determine if {} <= {}, treat as equal'. 
+ format(e, new_sympy_shape[i])) + e = new_sympy_shape[i] + + s = handle_negative_index(s, new_sympy_shape[i]) + if is_literal(new_sympy_shape[i]) and is_literal(s): + s = max(0, min(s, new_sympy_shape[i])) + + new_sympy_shape[i] = sympy.simplify( + (e - s + t + (-1 if t > 0 else 1)) // t) + + self._update_computed_dims(new_sympy_shape) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape))) + + # handle sympy_data if needed, for slice in shape computation + if (node.input[0] in self.sympy_data_ and [0] == axes and + len(starts) == 1 and len(ends) == 1 and len(steps) == 1): + input_sympy_data = self.sympy_data_[node.input[0]] + if type(input_sympy_data) == list or ( + type(input_sympy_data) == np.array and + len(input_sympy_data.shape) == 1): + self.sympy_data_[node.output[0]] = input_sympy_data[starts[ + 0]:ends[0]:steps[0]] + + def _infer_SoftmaxCrossEntropyLoss(self, node): + vi = self.known_vi_[node.output[0]] + elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.type.tensor_type.elem_type = elem_type + vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + + if len(node.output) > 1: + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, elem_type, data_shape)) + + def _infer_Split_Common(self, node, make_value_info_func): + input_sympy_shape = self._get_sympy_shape(node, 0) + axis = handle_negative_axis( + get_attribute(node, 'axis', 0), len(input_sympy_shape)) + split = get_attribute(node, 'split') + if not split: + num_outputs = len(node.output) + split = [input_sympy_shape[axis] / + sympy.Integer(num_outputs)] * num_outputs + self._update_computed_dims(split) + else: + split = [sympy.Integer(s) for s in split] + + for i_o in range(len(split)): + vi = self.known_vi_[node.output[i_o]] + vi.CopyFrom( + 
make_value_info_func(node.output[i_o], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape( + input_sympy_shape[:axis] + [ + split[i_o] + ] + input_sympy_shape[axis + 1:]))) + self.known_vi_[vi.name] = vi + + def _infer_Split(self, node): + self._infer_Split_Common(node, helper.make_tensor_value_info) + + def _infer_SplitToSequence(self, node): + self._infer_Split_Common(node, helper.make_sequence_value_info) + + def _infer_Squeeze(self, node): + input_shape = self._get_shape(node, 0) + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'axes' are provided as attribute or via 2nd input + if op_set < 13: + axes = get_attribute(node, 'axes') + assert self._try_get_value(node, 1) is None + else: + axes = self._try_get_value(node, 1) + assert get_attribute(node, 'axes') is None + + if axes is None: + # No axes have been provided (neither via attribute nor via input). + # In this case the 'Shape' op should remove all axis with dimension 1. + # For symbolic dimensions we guess they are !=1. + output_shape = [s for s in input_shape if s != 1] + if self.verbose_ > 0: + symbolic_dimensions = [s for s in input_shape if type(s) != int] + if len(symbolic_dimensions) > 0: + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + + + f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" + ) + else: + axes = [handle_negative_axis(a, len(input_shape)) for a in axes] + output_shape = [] + for i in range(len(input_shape)): + if i not in axes: + output_shape.append(input_shape[i]) + else: + assert input_shape[i] == 1 or type(input_shape[i]) != int + if self.verbose_ > 0 and type(input_shape[i]) != int: + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + + + f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." 
+ ) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, output_shape)) + self._pass_on_sympy_data(node) + + def _infer_Tile(self, node): + repeats_value = self._try_get_value(node, 1) + new_sympy_shape = [] + if repeats_value is not None: + input_sympy_shape = self._get_sympy_shape(node, 0) + for i, d in enumerate(input_sympy_shape): + new_dim = d * repeats_value[i] + new_sympy_shape.append(new_dim) + self._update_computed_dims(new_sympy_shape) + else: + new_sympy_shape = self._new_symbolic_shape( + self._get_shape_rank(node, 0), node) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape))) + + def _infer_TopK(self, node): + rank = self._get_shape_rank(node, 0) + axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank) + new_shape = self._get_shape(node, 0) + + if get_opset(self.out_mp_) <= 9: + k = get_attribute(node, 'k') + else: + k = self._get_int_values(node)[1] + + if k == None: + k = self._new_symbolic_dim_from_output(node) + else: + k = as_scalar(k) + + if type(k) in [int, str]: + new_shape[axis] = k + else: + new_sympy_shape = self._get_sympy_shape(node, 0) + new_sympy_shape[axis] = k + self._update_computed_dims( + new_sympy_shape + ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape + new_shape = get_shape_from_sympy_shape(new_sympy_shape) + + for i_o in range(len(node.output)): + vi = self.known_vi_[node.output[i_o]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + i_o], vi.type.tensor_type.elem_type, new_shape)) + + def _infer_Transpose(self, node): + if node.input[0] in self.sympy_data_: + data_shape = self._get_shape(node, 0) + perm = get_attribute(node, 'perm', + reversed(list(range(len(data_shape))))) + input_data = 
self.sympy_data_[node.input[0]] + self.sympy_data_[node.output[0]] = np.transpose( + np.array(input_data).reshape(*data_shape), + axes=tuple(perm)).flatten().tolist() + + def _infer_Unsqueeze(self, node): + input_shape = self._get_shape(node, 0) + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'axes' are provided as attribute or via 2nd input + if op_set < 13: + axes = get_attribute(node, 'axes') + assert self._try_get_value(node, 1) is None + else: + axes = self._try_get_value(node, 1) + assert get_attribute(node, 'axes') is None + + output_rank = len(input_shape) + len(axes) + axes = [handle_negative_axis(a, output_rank) for a in axes] + + input_axis = 0 + output_shape = [] + for i in range(output_rank): + if i in axes: + output_shape.append(1) + else: + output_shape.append(input_shape[input_axis]) + input_axis += 1 + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], self.known_vi_[ + node.input[0]].type.tensor_type.elem_type, output_shape)) + + self._pass_on_sympy_data(node) + + def _infer_ZipMap(self, node): + map_key_type = None + if get_attribute(node, 'classlabels_int64s') is not None: + map_key_type = onnx.TensorProto.INT64 + elif get_attribute(node, 'classlabels_strings') is not None: + map_key_type = onnx.TensorProto.STRING + + assert map_key_type is not None + new_vi = onnx.ValueInfoProto() + new_vi.name = node.output[0] + new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT + new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(new_vi) + + def _infer_Attention(self, node): + shape = self._get_shape(node, 0) + shape_bias = self._get_shape(node, 2) + assert len(shape) == 3 and len(shape_bias) == 1 + qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes') + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = 
int(qkv_hidden_sizes_attr[2]) + else: + shape[2] = int(shape_bias[0] / 3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = self._get_shape(node, 4) + mask_shape = self._get_shape(node, 3) + if len(past_shape) == 5: + if len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + elif isinstance(input_shape[1], int) and isinstance( + past_shape[3], int): + past_shape[3] = input_shape[1] + past_shape[3] + else: + past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(vi.name, output_dtype, + past_shape)) + + def _infer_BiasGelu(self, node): + self._propagate_shape_and_type(node) + + def _infer_FastGelu(self, node): + self._propagate_shape_and_type(node) + + def _infer_Gelu(self, node): + self._propagate_shape_and_type(node) + + def _infer_LayerNormalization(self, node): + self._propagate_shape_and_type(node) + + def _infer_LongformerAttention(self, node): + self._propagate_shape_and_type(node) + + def _infer_EmbedLayerNormalization(self, node): + input_ids_shape = self._get_shape(node, 0) + word_embedding_shape = self._get_shape(node, 2) + assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 + output_shape = input_ids_shape + [word_embedding_shape[1]] + + word_embedding_dtype = self.known_vi_[node.input[ + 
2]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], word_embedding_dtype, + output_shape)) + + mask_index_shape = [input_ids_shape[0]] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 1], onnx.TensorProto.INT32, mask_index_shape)) + + if len(node.output) > 2: + # Optional output of add before layer nomalization is done + # shape is same as the output + vi = self.known_vi_[node.output[2]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[ + 2], word_embedding_dtype, output_shape)) + + def _infer_SkipLayerNormalization(self, node): + self._propagate_shape_and_type(node) + + def _infer_PythonOp(self, node): + output_tensor_types = get_attribute(node, 'output_tensor_types') + assert output_tensor_types + output_tensor_ranks = get_attribute(node, 'output_tensor_ranks') + assert output_tensor_ranks + + # set the context output seperately. + # The first output is autograd's context. + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[0], + onnx.TensorProto.INT64, [])) + + # Outputs after autograd's context are tensors. + # We assume their ranks are fixed for different model inputs. + for i in range(len(node.output) - 1): + # Process the i-th tensor outputs. 
+ vi = self.known_vi_[node.output[i + 1]] + sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) + shape = get_shape_from_sympy_shape(sympy_shape) + value_info = helper.make_tensor_value_info( + node.output[i + 1], output_tensor_types[i], shape) + vi.CopyFrom(value_info) + + def _propagate_shape_and_type(self, node, input_index=0, output_index=0): + shape = self._get_shape(node, input_index) + output_dtype = self.known_vi_[node.input[ + input_index]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[output_index]] + vi.CopyFrom( + helper.make_tensor_value_info(node.output[output_index], + output_dtype, shape)) + + def _is_none_dim(self, dim_value): + if type(dim_value) != str: + return False + if "unk__" not in dim_value: + return False + if dim_value in self.symbolic_dims_.keys(): + return False + return True + + def _is_shape_contains_none_dim(self, out_shape): + for out in out_shape: + if self._is_none_dim(out): + return out + return None + + def _infer_impl(self, start_sympy_data=None): + self.sympy_data_ = start_sympy_data or {} + self.out_mp_.graph.ClearField('value_info') + self._apply_suggested_merge(graph_input_only=True) + self.input_symbols_ = set() + for i in self.out_mp_.graph.input: + input_shape = get_shape_from_value_info(i) + if input_shape is None: + continue + + if is_sequence(i.type): + input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim + else: + input_dims = i.type.tensor_type.shape.dim + + for i_dim, dim in enumerate(input_shape): + if dim is None: + # some models use None for symbolic dim in input, replace it with a string + input_dims[i_dim].dim_param = str( + self._new_symbolic_dim(i.name, i_dim)) + + self.input_symbols_.update( + [d for d in input_shape if type(d) == str]) + + for s in self.input_symbols_: + if s in self.suggested_merge_: + s_merge = self.suggested_merge_[s] + assert s_merge in self.symbolic_dims_ + self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] + else: + # Since inputs are 
not produced by other ops, we can assume positivity + self.symbolic_dims_[s] = sympy.Symbol( + s, integer=True, positive=True) + # create a temporary ModelProto for single node inference + # note that we remove initializer to have faster inference + # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways + self.tmp_mp_ = onnx.ModelProto() + self.tmp_mp_.CopyFrom(self.out_mp_) + self.tmp_mp_.graph.ClearField('initializer') + + # compute prerequesite for node for topological sort + # node with subgraphs may have dependency on implicit inputs, which will affect topological sort + prereq_for_node = { + } # map from node to all its inputs, including implicit ones in subgraph + + def get_prereq(node): + names = set(i for i in node.input if i) + subgraphs = [] + if 'If' == node.op_type: + subgraphs = [ + get_attribute(node, 'then_branch'), + get_attribute(node, 'else_branch') + ] + elif node.op_type in ['Loop', 'Scan']: + subgraphs = [get_attribute(node, 'body')] + for g in subgraphs: + g_outputs_and_initializers = {i.name for i in g.initializer} + g_prereq = set() + for n in g.node: + g_outputs_and_initializers.update(n.output) + for n in g.node: + g_prereq.update([ + i for i in get_prereq(n) + if i not in g_outputs_and_initializers + ]) + names.update(g_prereq) + # remove subgraph inputs from g_prereq since those are local-only + for i in g.input: + if i.name in names: + names.remove(i.name) + return names + + for n in self.tmp_mp_.graph.node: + prereq_for_node[n.output[0]] = get_prereq(n) + + # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate + sorted_nodes = [] + sorted_known_vi = set([ + i.name for i in list(self.out_mp_.graph.input) + + list(self.out_mp_.graph.initializer) + ]) + if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): + # Loop/Scan will have some graph output in graph inputs, so don't do topological sort + 
sorted_nodes = self.out_mp_.graph.node + else: + while not all( + [o.name in sorted_known_vi for o in self.out_mp_.graph.output]): + old_sorted_nodes_len = len(sorted_nodes) + for node in self.out_mp_.graph.node: + if (node.output[0] not in sorted_known_vi) and all([ + i in sorted_known_vi + for i in prereq_for_node[node.output[0]] if i + ]): + sorted_known_vi.update(node.output) + sorted_nodes.append(node) + if old_sorted_nodes_len == len(sorted_nodes) and not all([ + o.name in sorted_known_vi + for o in self.out_mp_.graph.output + ]): + raise Exception('Invalid model with cyclic graph') + + for node in sorted_nodes: + assert all([i in self.known_vi_ for i in node.input if i]) + self._onnx_infer_single_node(node) + known_aten_op = False + if node.op_type in self.dispatcher_: + self.dispatcher_[node.op_type](node) + elif node.op_type in ['ConvTranspose']: + # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input + # before adding symbolic compute for them + # mark the output type as UNDEFINED to allow guessing of rank + vi = self.known_vi_[node.output[0]] + if len(vi.type.tensor_type.shape.dim) == 0: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten': + for attr in node.attribute: + # TODO: Is overload_name needed? + if attr.name == 'operator': + aten_op_name = attr.s.decode('utf-8') if isinstance( + attr.s, bytes) else attr.s + if aten_op_name in self.aten_op_dispatcher_: + known_aten_op = True + self.aten_op_dispatcher_[aten_op_name](node) + break + + if self.verbose_ > 2: + logger.debug(node.op_type + ': ' + node.name) + for i, name in enumerate(node.input): + logger.debug(' Input {}: {} {}'.format( + i, name, 'initializer' + if name in self.initializers_ else '')) + + # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] + # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case + if node.op_type in [ + 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', + 'MatMulInteger16', 'Where', 'Sum' + ]: + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [ + self._get_shape(node, i) for i in range(len(node.input)) + ] + for d in range(out_rank - (2 if node.op_type in [ + 'MatMul', 'MatMulInteger', 'MatMulInteger16' + ] else 0)): + in_dims = [ + s[len(s) - out_rank + d] for s in in_shapes + if len(s) + d >= out_rank + ] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + + for i_o in range(len(node.output)): + vi = self.known_vi_[node.output[i_o]] + out_type = vi.type + out_type_kind = out_type.WhichOneof('value') + + # do not process shape for non-tensors + if out_type_kind not in [ + 'tensor_type', 'sparse_tensor_type', None + ]: + if self.verbose_ > 2: + if out_type_kind == 'sequence_type': + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof( + 'value') + if 'tensor_type' == seq_cls_type: + logger.debug(' {}: sequence of {} {}'.format( + node.output[i_o], + str(get_shape_from_value_info(vi)), + onnx.TensorProto.DataType.Name( + vi.type.sequence_type.elem_type. 
+ tensor_type.elem_type))) + else: + logger.debug(' {}: sequence of {}'.format( + node.output[i_o], seq_cls_type)) + else: + logger.debug(' {}: {}'.format(node.output[i_o], + out_type_kind)) + continue + + out_shape = get_shape_from_value_info(vi) + out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED + if self.verbose_ > 2: + logger.debug(' {}: {} {}'.format( + node.output[i_o], + str(out_shape), + onnx.TensorProto.DataType.Name( + vi.type.tensor_type.elem_type))) + if node.output[i_o] in self.sympy_data_: + logger.debug(' Sympy Data: ' + str(self.sympy_data_[ + node.output[i_o]])) + + # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain + if (out_shape is not None and + (None in out_shape or + self._is_shape_contains_none_dim(out_shape)) + ) or out_type_undefined: + if self.auto_merge_: + if node.op_type in [ + 'Add', 'Sub', 'Mul', 'Div', 'MatMul', + 'MatMulInteger', 'MatMulInteger16', 'Concat', + 'Where', 'Sum', 'Equal', 'Less', 'Greater', + 'LessOrEqual', 'GreaterOrEqual' + ]: + shapes = [ + self._get_shape(node, i) + for i in range(len(node.input)) + ] + if node.op_type in [ + 'MatMul', 'MatMulInteger', 'MatMulInteger16' + ]: + if None in out_shape or self._is_shape_contains_none_dim( + out_shape): + if None in out_shape: + idx = out_shape.index(None) + else: + idx = out_shape.index( + self._is_shape_contains_none_dim( + out_shape)) + dim_idx = [ + len(s) - len(out_shape) + idx + for s in shapes + ] + # only support auto merge for MatMul for dim < rank-2 when rank > 2 + assert len( + shapes[0]) > 2 and dim_idx[0] < len( + shapes[0]) - 2 + assert len( + shapes[1]) > 2 and dim_idx[1] < len( + shapes[1]) - 2 + elif node.op_type == 'Expand': + # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) + shapes = [ + self._get_shape(node, 0), self._get_value(node, + 1) + ] + else: + shapes = [] + + if shapes: + for idx in range(len(out_shape)): + if out_shape[ + idx] is not None and not 
self._is_none_dim( + out_shape[idx]): + continue + # note that the broadcasting rule aligns from right to left + # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge + dim_idx = [ + len(s) - len(out_shape) + idx + for s in shapes + ] + if len(dim_idx) > 0: + self._add_suggested_merge([ + s[i] if is_literal(s[i]) else str(s[i]) + for s, i in zip(shapes, dim_idx) + if i >= 0 + ]) + self.run_ = True + else: + self.run_ = False + else: + self.run_ = False + + # create new dynamic dims for ops not handled by symbolic shape inference + if self.run_ == False and not node.op_type in self.dispatcher_ and not known_aten_op: + is_unknown_op = out_type_undefined and ( + out_shape is None or len(out_shape) == 0) + if is_unknown_op: + # unknown op to ONNX, maybe from higher opset or other domain + # only guess the output rank from input 0 when using guess_output_rank option + out_rank = self._get_shape_rank( + node, 0) if self.guess_output_rank_ else -1 + else: + # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape + out_rank = len(out_shape) + + if out_rank >= 0: + new_shape = self._new_symbolic_shape(out_rank, node, + i_o) + if out_type_undefined: + # guess output data type from input vi if not defined + out_dtype = self.known_vi_[node.input[ + 0]].type.tensor_type.elem_type + else: + # otherwise, use original data type + out_dtype = vi.type.tensor_type.elem_type + vi.CopyFrom( + helper.make_tensor_value_info( + vi.name, out_dtype, + get_shape_from_sympy_shape(new_shape))) + + if self.verbose_ > 0: + if is_unknown_op: + logger.debug( + "Possible unknown op: {} node: {}, guessing {} shape". 
+ format(node.op_type, node.name, + vi.name)) + if self.verbose_ > 2: + logger.debug(' {}: {} {}'.format( + node.output[i_o], + str(new_shape), + vi.type.tensor_type.elem_type)) + + self.run_ = True + continue # continue the inference after guess, no need to stop as no merge is needed + + if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: + logger.debug( + 'Stopping at incomplete shape inference at ' + + node.op_type + ': ' + node.name) + logger.debug('node inputs:') + for i in node.input: + logger.debug(self.known_vi_[i]) + logger.debug('node outputs:') + for o in node.output: + logger.debug(self.known_vi_[o]) + if self.auto_merge_ and not out_type_undefined: + logger.debug('Merging: ' + str( + self.suggested_merge_)) + return False + + self.run_ = False + return True + + def _update_output_from_vi(self): + for output in self.out_mp_.graph.output: + if output.name in self.known_vi_: + output.CopyFrom(self.known_vi_[output.name]) + + @staticmethod + def infer_shapes(in_mp, + int_max=2**31 - 1, + auto_merge=False, + guess_output_rank=False, + verbose=0): + onnx_opset = get_opset(in_mp) + if (not onnx_opset) or onnx_opset < 7: + logger.warning('Only support models of onnx opset 7 and above.') + return None + symbolic_shape_inference = SymbolicShapeInference( + int_max, auto_merge, guess_output_rank, verbose) + all_shapes_inferred = False + symbolic_shape_inference._preprocess(in_mp) + while symbolic_shape_inference.run_: + all_shapes_inferred = symbolic_shape_inference._infer_impl() + symbolic_shape_inference._update_output_from_vi() + if not all_shapes_inferred: + raise Exception("Incomplete symbolic shape inference") + return symbolic_shape_inference.out_mp_ + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input', required=True, help='The input model file') + parser.add_argument('--output', help='The output model file') + parser.add_argument( + '--auto_merge', + help='Automatically merge symbolic dims when 
confliction happens', + action='store_true', + default=False) + parser.add_argument( + '--int_max', + help='maximum value for integer to be treated as boundless for ops like slice', + type=int, + default=2**31 - 1) + parser.add_argument( + '--guess_output_rank', + help='guess output rank to be the same as input 0 for unknown ops', + action='store_true', + default=False) + parser.add_argument( + '--verbose', + help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed', + type=int, + default=0) + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + logger.info('input model: ' + args.input) + if args.output: + logger.info('output model ' + args.output) + logger.info('Doing symbolic shape inference...') + out_mp = SymbolicShapeInference.infer_shapes( + onnx.load(args.input), args.int_max, args.auto_merge, + args.guess_output_rank, args.verbose) + if args.output and out_mp: + onnx.save(out_mp, args.output) + logger.info('Done!') diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh b/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh new file mode 100755 index 000000000..ce2f24e58 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +set -e + +if [ $# != 3 ];then + # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" + echo "usage: $0 onnx.model.in onnx.model.out input_shape " + exit 1 +fi + +# onnx optimizer +pip install onnx-simplifier + +in=$1 +out=$2 +input_shape=$3 + +check_n=3 + +onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py new file mode 100755 index 000000000..5b85eef3e --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 -W 
ignore::DeprecationWarning +# prune model by output names +import argparse +import copy +import sys + +import onnx + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + required=True, + help='Path of directory saved the input model.') + parser.add_argument( + '--output_names', + required=True, + nargs='+', + help='The outputs of pruned model.') + parser.add_argument( + '--save_file', required=True, help='Path to save the new onnx model.') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + if len(set(args.output_names)) < len(args.output_names): + print( + "[ERROR] There's dumplicate name in --output_names, which is not allowed." + ) + sys.exit(-1) + + model = onnx.load(args.model) + + # collect all node outputs and graph output + output_tensor_names = set() + for node in model.graph.node: + for out in node.output: + # may contain model output + output_tensor_names.add(out) + + # for out in model.graph.output: + # output_tensor_names.add(out.name) + + for output_name in args.output_names: + if output_name not in output_tensor_names: + print( + "[ERROR] Cannot find output tensor name '{}' in onnx model graph.". 
+ format(output_name)) + sys.exit(-1) + + output_node_indices = set() # has output names + output_to_node = dict() # all node outputs + for i, node in enumerate(model.graph.node): + for out in node.output: + output_to_node[out] = i + if out in args.output_names: + output_node_indices.add(i) + + # from outputs find all the ancestors + reserved_node_indices = copy.deepcopy( + output_node_indices) # nodes need to keep + reserved_inputs = set() # model input to keep + new_output_node_indices = copy.deepcopy(output_node_indices) + + while True and len(new_output_node_indices) > 0: + output_node_indices = copy.deepcopy(new_output_node_indices) + + new_output_node_indices = set() + + for out_node_idx in output_node_indices: + # backtrace to parenet + for ipt in model.graph.node[out_node_idx].input: + if ipt in output_to_node: + reserved_node_indices.add(output_to_node[ipt]) + new_output_node_indices.add(output_to_node[ipt]) + else: + reserved_inputs.add(ipt) + + num_inputs = len(model.graph.input) + num_outputs = len(model.graph.output) + num_nodes = len(model.graph.node) + print( + f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes" + ) + print(f"{len(reserved_node_indices)} node to keep.") + + # del node not to keep + for idx in range(num_nodes - 1, -1, -1): + if idx not in reserved_node_indices: + del model.graph.node[idx] + + # del graph input not to keep + for idx in range(num_inputs - 1, -1, -1): + if model.graph.input[idx].name not in reserved_inputs: + del model.graph.input[idx] + + # del old graph outputs + for i in range(num_outputs): + del model.graph.output[0] + + # new graph output as user input + for out in args.output_names: + model.graph.output.extend([onnx.ValueInfoProto(name=out)]) + + # infer shape + try: + from onnx_infer_shape import SymbolicShapeInference + model = SymbolicShapeInference.infer_shapes( + model, + int_max=2**31 - 1, + auto_merge=True, + guess_output_rank=False, + verbose=1) + except Exception as e: + 
print(f"skip infer shape step: {e}") + + # check onnx model + onnx.checker.check_model(model) + # save onnx model + onnx.save(model, args.save_file) + print("[Finished] The new model saved in {}.".format(args.save_file)) + print("[DEBUG INFO] The inputs of new model: {}".format( + [x.name for x in model.graph.input])) + print("[DEBUG INFO] The outputs of new model: {}".format( + [x.name for x in model.graph.output])) diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py new file mode 100755 index 000000000..fc00a82ec --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 -W ignore::DeprecationWarning +# rename node to new names +import argparse +import sys + +import onnx + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + required=True, + help='Path of directory saved the input model.') + parser.add_argument( + '--origin_names', + required=True, + nargs='+', + help='The original name you want to modify.') + parser.add_argument( + '--new_names', + required=True, + nargs='+', + help='The new name you want change to, the number of new_names should be same with the number of origin_names' + ) + parser.add_argument( + '--save_file', required=True, help='Path to save the new onnx model.') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + if len(set(args.origin_names)) < len(args.origin_names): + print( + "[ERROR] There's dumplicate name in --origin_names, which is not allowed." + ) + sys.exit(-1) + + if len(set(args.new_names)) < len(args.new_names): + print( + "[ERROR] There's dumplicate name in --new_names, which is not allowed." + ) + sys.exit(-1) + + if len(args.new_names) != len(args.origin_names): + print( + "[ERROR] Number of --new_names must be same with the number of --origin_names." 
+ ) + sys.exit(-1) + + model = onnx.load(args.model) + + # collect input and all node output + output_tensor_names = set() + for ipt in model.graph.input: + output_tensor_names.add(ipt.name) + + for node in model.graph.node: + for out in node.output: + output_tensor_names.add(out) + + for origin_name in args.origin_names: + if origin_name not in output_tensor_names: + print( + f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph." + ) + sys.exit(-1) + + for new_name in args.new_names: + if new_name in output_tensor_names: + print( + "[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed." + ) + sys.exit(-1) + + # rename graph input + for i, ipt in enumerate(model.graph.input): + if ipt.name in args.origin_names: + idx = args.origin_names.index(ipt.name) + model.graph.input[i].name = args.new_names[idx] + + # rename node input and output + for i, node in enumerate(model.graph.node): + for j, ipt in enumerate(node.input): + if ipt in args.origin_names: + idx = args.origin_names.index(ipt) + model.graph.node[i].input[j] = args.new_names[idx] + + for j, out in enumerate(node.output): + if out in args.origin_names: + idx = args.origin_names.index(out) + model.graph.node[i].output[j] = args.new_names[idx] + + # rename graph output + for i, out in enumerate(model.graph.output): + if out.name in args.origin_names: + idx = args.origin_names.index(out.name) + model.graph.output[i].name = args.new_names[idx] + + # check onnx model + onnx.checker.check_model(model) + + # save model + onnx.save(model, args.save_file) + + print("[Finished] The new model saved in {}.".format(args.save_file)) + print("[DEBUG INFO] The inputs of new model: {}".format( + [x.name for x in model.graph.input])) + print("[DEBUG INFO] The outputs of new model: {}".format( + [x.name for x in model.graph.output])) diff --git a/speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py new file mode 100755 
index 000000000..c6e693c6b --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 -W ignore::DeprecationWarning +# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape +import argparse + +# paddle inference shape + + +def process_old_ops_desc(program): + """set matmul op head_number attr to 1 is not exist. + + Args: + program (_type_): _description_ + """ + for i in range(len(program.blocks[0].ops)): + if program.blocks[0].ops[i].type == "matmul": + if not program.blocks[0].ops[i].has_attr("head_number"): + program.blocks[0].ops[i]._set_attr("head_number", 1) + + +def infer_shape(program, input_shape_dict): + # 2002002 + model_version = program.desc._version() + # 2.2.2 + paddle_version = paddle.__version__ + major_ver = model_version // 1000000 + minor_ver = (model_version - major_ver * 1000000) // 1000 + patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000 + model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver) + if model_version != paddle_version: + print( + f"[WARNING] The model is saved by paddlepaddle v{model_version}, but now your paddlepaddle is version of {paddle_version}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model" + ) + + OP_WITHOUT_KERNEL_SET = { + 'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad', + 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', + 'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify', + 'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id', + 'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream', + 'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv', + 'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl', + 'copy_cross_scope' + } + + for k, v in input_shape_dict.items(): + program.blocks[0].var(k).desc.set_shape(v) + + for 
i in range(len(program.blocks)): + for j in range(len(program.blocks[0].ops)): + # for ops + if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET: + print(f"not infer: {program.blocks[i].ops[j].type} op") + continue + print(f"infer: {program.blocks[i].ops[j].type} op") + program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc) + + +def parse_arguments(): + # python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \ + # --model_filename avg_1.jit.pdmodel\ + # --params_filename avg_1.jit.pdiparams \ + # --save_dir . \ + # --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}" + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_dir', + required=True, + help='Path of directory saved the input model.') + parser.add_argument( + '--model_filename', required=True, help='model.pdmodel.') + parser.add_argument( + '--params_filename', required=True, help='model.pdiparams.') + parser.add_argument( + '--save_dir', + required=True, + help='directory to save the exported model.') + parser.add_argument( + '--input_shape_dict', required=True, help="The new shape information.") + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + import paddle + paddle.enable_static() + import paddle.fluid as fluid + + input_shape_dict_str = args.input_shape_dict + input_shape_dict = eval(input_shape_dict_str) + + print("Start to load paddle model...") + exe = fluid.Executor(fluid.CPUPlace()) + + prog, ipts, outs = fluid.io.load_inference_model( + args.model_dir, + exe, + model_filename=args.model_filename, + params_filename=args.params_filename) + + process_old_ops_desc(prog) + infer_shape(prog, input_shape_dict) + + fluid.io.save_inference_model( + args.save_dir, + ipts, + outs, + exe, + prog, + model_filename=args.model_filename, + params_filename=args.params_filename) diff --git 
a/speechx/examples/ds2_ol/onnx/local/pd_prune_model.py b/speechx/examples/ds2_ol/onnx/local/pd_prune_model.py new file mode 100755 index 000000000..5386a971a --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/pd_prune_model.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 -W ignore::DeprecationWarning +# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#1-%E8%A3%81%E5%89%AApaddle%E6%A8%A1%E5%9E%8B +import argparse +import sys +from typing import List + +# paddle prune model. + + +def prepend_feed_ops(program, + feed_target_names: List[str], + feed_holder_name='feed'): + import paddle.fluid.core as core + if len(feed_target_names) == 0: + return + + global_block = program.global_block() + feed_var = global_block.create_var( + name=feed_holder_name, + type=core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True, ) + + for i, name in enumerate(feed_target_names, 0): + if not global_block.has_var(name): + print( + f"The input[{i}]: '{name}' doesn't exist in pruned inference program, which will be ignored in new saved model." + ) + continue + + out = global_block.var(name) + global_block._prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}, ) + + +def append_fetch_ops(program, + fetch_target_names: List[str], + fetch_holder_name='fetch'): + """in the place, we will add the fetch op + + Args: + program (_type_): inference program + fetch_target_names (List[str]): target names + fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'. 
+ """ + import paddle.fluid.core as core + global_block = program.global_block() + fetch_var = global_block.create_var( + name=fetch_holder_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True, ) + + print(f"the len of fetch_target_names: {len(fetch_target_names)}") + + for i, name in enumerate(fetch_target_names): + global_block.append_op( + type='fetch', + inputs={'X': [name]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}, ) + + +def insert_fetch(program, + fetch_target_names: List[str], + fetch_holder_name='fetch'): + """in the place, we will add the fetch op + + Args: + program (_type_): inference program + fetch_target_names (List[str]): target names + fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'. + """ + global_block = program.global_block() + + # remove fetch + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + if op.type == 'fetch': + need_to_remove_op_index.append(i) + + for index in reversed(need_to_remove_op_index): + global_block._remove_op(index) + + program.desc.flush() + + # append new fetch + append_fetch_ops(program, fetch_target_names, fetch_holder_name) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_dir', + required=True, + help='Path of directory saved the input model.') + parser.add_argument( + '--model_filename', required=True, help='model.pdmodel.') + parser.add_argument( + '--params_filename', required=True, help='model.pdiparams.') + parser.add_argument( + '--output_names', + required=True, + help='The outputs of model. 
sep by comma') + parser.add_argument( + '--save_dir', + required=True, + help='directory to save the exported model.') + parser.add_argument('--debug', default=False, help='output debug info.') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + args.output_names = args.output_names.split(",") + + if len(set(args.output_names)) < len(args.output_names): + print( + f"[ERROR] There's dumplicate name in --output_names {args.output_names}, which is not allowed." + ) + sys.exit(-1) + + import paddle + paddle.enable_static() + # hack prepend_feed_ops + paddle.fluid.io.prepend_feed_ops = prepend_feed_ops + + import paddle.fluid as fluid + + print("start to load paddle model") + exe = fluid.Executor(fluid.CPUPlace()) + prog, ipts, outs = fluid.io.load_inference_model( + args.model_dir, + exe, + model_filename=args.model_filename, + params_filename=args.params_filename) + + print("start to load insert fetch op") + new_outputs = [] + insert_fetch(prog, args.output_names) + for out_name in args.output_names: + new_outputs.append(prog.global_block().var(out_name)) + + # not equal to paddle.static.save_inference_model + fluid.io.save_inference_model( + args.save_dir, + ipts, + new_outputs, + exe, + prog, + model_filename=args.model_filename, + params_filename=args.params_filename) + + if args.debug: + for op in prog.global_block().ops: + print(op) diff --git a/speechx/examples/ds2_ol/onnx/local/prune.sh b/speechx/examples/ds2_ol/onnx/local/prune.sh new file mode 100755 index 000000000..64636bccf --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/prune.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +set -e + +if [ $# != 5 ]; then + # local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD + echo "usage: $0 model_dir model_filename param_filename outputs_names save_dir" + exit 1 +fi + +dir=$1 +model=$2 +param=$3 +outputs=$4 +save_dir=$5 + + +python 
local/pd_prune_model.py \ + --model_dir $dir \ + --model_filename $model \ + --params_filename $param \ + --output_names $outputs \ + --save_dir $save_dir \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/tonnx.sh b/speechx/examples/ds2_ol/onnx/local/tonnx.sh new file mode 100755 index 000000000..ffedf001c --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/local/tonnx.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +if [ $# != 4 ];then + # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx + echo "usage: $0 model_dir model_name param_name onnx_output_name" + exit 1 +fi + +dir=$1 +model=$2 +param=$3 +output=$4 + +pip install paddle2onnx +pip install onnx + +# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2 +paddle2onnx --model_dir $dir \ + --model_filename $model \ + --params_filename $param \ + --save_file $output \ + --enable_dev_version True \ + --opset_version 9 \ + --enable_onnx_checker True + \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/path.sh b/speechx/examples/ds2_ol/onnx/path.sh new file mode 100755 index 000000000..97d487379 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/path.sh @@ -0,0 +1,14 @@ +# This contains the locations of binarys build required for running the examples. + +MAIN_ROOT=`realpath $PWD/../../../../` +SPEECHX_ROOT=$PWD/../../../ +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx + +SPEECHX_TOOLS=$SPEECHX_ROOT/tools +TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin + +[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } + +export LC_AL=C + +export PATH=$PATH:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh new file mode 100755 index 000000000..dda5a57a6 --- /dev/null +++ b/speechx/examples/ds2_ol/onnx/run.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +set -e + +. 
path.sh + +stage=0 +stop_stage=100 + +. utils/parse_options.sh + +data=data +exp=exp + +mkdir -p $data $exp + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then + test -f $data/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz || wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz -P $data + + # wenetspeech ds2 model + pushd $data + tar zxvf asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz + popd + + # ds2 model demo inputs + pushd $exp + wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle + popd + + +fi + +dir=$data/exp/deepspeech2_online/checkpoints +model=avg_1.jit.pdmodel +param=avg_1.jit.pdiparams + + +output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then + # prune model by outputs + mkdir -p $exp/prune + + # prune model deps on output_names. + ./local/prune.sh $dir $model $param $output_names $exp/prune +fi + +input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then + # infer shape by new shape + mkdir -p $exp/shape + python3 local/pd_infer_shape.py \ + --model_dir $dir \ + --model_filename $model \ + --params_filename $param \ + --save_dir $exp/shape \ + --input_shape_dict="${input_shape_dict}" +fi + +input_file=$exp/static_ds2online_inputs.pickle +test -e $input_file + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then + # to onnx + ./local/tonnx.sh $dir $model $param $exp/model.onnx + + ./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.onnx +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then + input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" + # simplifying onnx model + ./local/onnx_opt.sh 
$exp/model.onnx $exp/model.opt.onnx "$input_shape" + + ./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.opt.onnx +fi \ No newline at end of file diff --git a/speechx/examples/ngram/zh/utils b/speechx/examples/ds2_ol/onnx/utils similarity index 100% rename from speechx/examples/ngram/zh/utils rename to speechx/examples/ds2_ol/onnx/utils diff --git a/speechx/examples/ds2_ol/websocket/CMakeLists.txt b/speechx/examples/ds2_ol/websocket/CMakeLists.txt deleted file mode 100644 index ed542aad0..000000000 --- a/speechx/examples/ds2_ol/websocket/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) -target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) - -add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) -target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) \ No newline at end of file diff --git a/speechx/examples/ds2_ol/websocket/path.sh b/speechx/examples/ds2_ol/websocket/path.sh index d66b5dcce..6dd6bddbf 100755 --- a/speechx/examples/ds2_ol/websocket/path.sh +++ b/speechx/examples/ds2_ol/websocket/path.sh @@ -1,14 +1,14 @@ # This contains the locations of binarys build required for running the examples. -SPEECHX_ROOT=$PWD/../../.. 
-SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples +SPEECHX_ROOT=$PWD/../../../ +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin -[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } +[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } export LC_AL=C -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/websocket:$SPEECHX_EXAMPLES/ds2_ol/feat +SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/websocket/websocket_client.sh b/speechx/examples/ds2_ol/websocket/websocket_client.sh index 2a52d2a3d..a508adfbc 100755 --- a/speechx/examples/ds2_ol/websocket/websocket_client.sh +++ b/speechx/examples/ds2_ol/websocket/websocket_client.sh @@ -32,4 +32,4 @@ export GLOG_logtostderr=1 # websocket client websocket_client_main \ - --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36 \ No newline at end of file + --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5 diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh index fc57e326f..18d29857c 100755 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ b/speechx/examples/ds2_ol/websocket/websocket_server.sh @@ -4,7 +4,6 @@ set -e . path.sh - # 1. compile if [ ! -d ${SPEECHX_EXAMPLES} ]; then pushd ${SPEECHX_ROOT} @@ -19,19 +18,6 @@ ckpt_dir=$data/model model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ vocb_dir=$ckpt_dir/data/lang_char/ -# output -aishell_wav_scp=aishell_test.scp -if [ ! 
-d $data/test ]; then - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp -fi - if [ ! -f $ckpt_dir/data/mean_std.json ]; then mkdir -p $ckpt_dir @@ -45,7 +31,7 @@ export GLOG_logtostderr=1 # 3. gen cmvn cmvn=$data/cmvn.ark -cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn +cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn wfst=$data/wfst/ @@ -62,7 +48,6 @@ fi websocket_server_main \ --cmvn_file=$cmvn \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --streaming_chunk=0.1 \ --param_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$wfst/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ diff --git a/speechx/examples/ngram/en/README.md b/speechx/examples/ngram/en/README.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/speechx/examples/ngram/zh/README.md b/speechx/examples/ngram/zh/README.md deleted file mode 100644 index e11bd3439..000000000 --- a/speechx/examples/ngram/zh/README.md +++ /dev/null @@ -1,101 +0,0 @@ -# ngram train for mandarin - -Quick run: -``` -bash run.sh --stage -1 -``` - -## input - -input files: -``` -data/ -├── lexicon.txt -├── text -└── vocab.txt -``` - -``` -==> data/text <== -BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购 -BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉 -BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 -BAC009S0002W0125 各地 政府 便 纷纷 跟进 -BAC009S0002W0126 仅 一 个 多 月 的 时间 里 -BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 -BAC009S0002W0128 四十六 个 限 购 城市 当中 -BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购 -BAC009S0002W0130 财政 金融 政策 紧随 其后 而来 -BAC009S0002W0131 显示 出 了 极 强 的 威力 - -==> data/lexicon.txt <== -SIL sil - sil -啊 aa a1 -啊 aa a2 -啊 aa a4 -啊 aa a5 -啊啊啊 
aa a2 aa a2 aa a2 -啊啊啊 aa a5 aa a5 aa a5 -坐地 z uo4 d i4 -坐实 z uo4 sh ix2 -坐视 z uo4 sh ix4 -坐稳 z uo4 uu un3 -坐拥 z uo4 ii iong1 -坐诊 z uo4 zh en3 -坐庄 z uo4 zh uang1 -坐姿 z uo4 z iy1 - -==> data/vocab.txt <== - - -A -B -C -D -E -龙 -龚 -龛 - -``` - -## output - -``` -data/ -├── local -│ ├── dict -│ │ ├── lexicon.txt -│ │ └── units.txt -│ └── lm -│ ├── heldout -│ ├── lm.arpa -│ ├── text -│ ├── text.no_oov -│ ├── train -│ ├── unigram.counts -│ ├── word.counts -│ └── wordlist -``` - -``` -/workspace/srilm/bin/i686-m64/ngram-count -Namespace(bpemodel=None, in_lexicon='data/lexicon.txt', out_lexicon='data/local/dict/lexicon.txt', unit_file='data/vocab.txt') -Ignoring words 矽, which contains oov unit -Ignoring words 傩, which contains oov unit -Ignoring words 堀, which contains oov unit -Ignoring words 莼, which contains oov unit -Ignoring words 菰, which contains oov unit -Ignoring words 摭, which contains oov unit -Ignoring words 帙, which contains oov unit -Ignoring words 迨, which contains oov unit -Ignoring words 孥, which contains oov unit -Ignoring words 瑗, which contains oov unit -... -... -... -file data/local/lm/heldout: 10000 sentences, 89496 words, 0 OOVs -0 zeroprobs, logprob= -270337.9 ppl= 521.2819 ppl1= 1048.745 -build LM done. -``` diff --git a/speechx/examples/ngram/zh/local/split_data.sh b/speechx/examples/ngram/zh/local/split_data.sh deleted file mode 100755 index 2af6fc5ab..000000000 --- a/speechx/examples/ngram/zh/local/split_data.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -set -eo pipefail - -data=$1 -scp=$2 -split_name=$3 -numsplit=$4 - -# save in $data/split{n} -# $scp to split -# - -if [[ ! $numsplit -gt 0 ]]; then - echo "Invalid num-split argument"; - exit 1; -fi - -directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done) -scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done) - -# if this mkdir fails due to argument-list being too long, iterate. -if ! 
mkdir -p $directories >&/dev/null; then - for n in `seq $numsplit`; do - mkdir -p $data/split${numsplit}/$n - done -fi - -echo "utils/split_scp.pl $scp $scp_splits" -utils/split_scp.pl $scp $scp_splits diff --git a/speechx/examples/ngram/zh/path.sh b/speechx/examples/ngram/zh/path.sh deleted file mode 100644 index a3fb3d758..000000000 --- a/speechx/examples/ngram/zh/path.sh +++ /dev/null @@ -1,12 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -MAIN_ROOT=`realpath $PWD/../../../../` -SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` - -export LC_AL=C - -# srilm -export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs -export SRILM=${MAIN_ROOT}/tools/srilm -export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64 diff --git a/speechx/examples/ngram/zh/run.sh b/speechx/examples/ngram/zh/run.sh deleted file mode 100755 index f24ad0a7c..000000000 --- a/speechx/examples/ngram/zh/run.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash -set -eo pipefail - -. path.sh - -stage=-1 -stop_stage=100 -corpus=aishell - -unit=data/vocab.txt # vocab file, line: char/spm_pice -lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt -text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt - -. utils/parse_options.sh - -data=$PWD/data -mkdir -p $data - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - if [ ! -f $data/speech.ngram.zh.tar.gz ];then - pushd $data - wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz - tar xvzf speech.ngram.zh.tar.gz - popd - fi -fi - -if [ ! -f $unit ]; then - echo "$0: No such file $unit" - exit 1; -fi - -if ! 
which ngram-count; then - pushd $MAIN_ROOT/tools - make srilm.done - popd -fi - -mkdir -p data/local/dict -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # 7.1 Prepare dict - # line: char/spm_pices - cp $unit data/local/dict/units.txt - - if [ ! -f $lexicon ];then - local/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon - echo "Generate $lexicon from $text" - fi - - # filter by vocab - # line: word ph0 ... phn -> line: word char0 ... charn - utils/fst/prepare_dict.py \ - --unit_file $unit \ - --in_lexicon ${lexicon} \ - --out_lexicon data/local/dict/lexicon.txt -fi - -lm=data/local/lm -mkdir -p $lm - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # 7.2 Train lm - cp $text $lm/text - local/aishell_train_lms.sh -fi - -echo "build LM done." -exit 0 diff --git a/speechx/examples/wfst/.gitignore b/speechx/examples/wfst/.gitignore deleted file mode 100644 index 1269488f7..000000000 --- a/speechx/examples/wfst/.gitignore +++ /dev/null @@ -1 +0,0 @@ -data diff --git a/speechx/examples/wfst/README.md b/speechx/examples/wfst/README.md deleted file mode 100644 index d0bdac0fc..000000000 --- a/speechx/examples/wfst/README.md +++ /dev/null @@ -1,186 +0,0 @@ -# Built TLG wfst - -## Input -``` -data/local/ -├── dict -│ ├── lexicon.txt -│ └── units.txt -└── lm - ├── heldout - ├── lm.arpa - ├── text - ├── text.no_oov - ├── train - ├── unigram.counts - ├── word.counts - └── wordlist -``` - -``` -==> data/local/dict/lexicon.txt <== -啊 啊 -啊啊啊 啊 啊 啊 -阿 阿 -阿尔 阿 尔 -阿根廷 阿 根 廷 -阿九 阿 九 -阿克 阿 克 -阿拉伯数字 阿 拉 伯 数 字 -阿拉法特 阿 拉 法 特 -阿拉木图 阿 拉 木 图 - -==> data/local/dict/units.txt <== - - -A -B -C -D -E -F -G -H - -==> data/local/lm/heldout <== -而 对 楼市 成交 抑制 作用 最 大 的 限 购 -也 成为 地方 政府 的 眼中 钉 -自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 -各地 政府 便 纷纷 跟进 -仅 一 个 多 月 的 时间 里 -除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 -四十六 个 限 购 城市 当中 -四十一 个 已 正式 取消 或 变相 放松 了 限 购 -财政 金融 政策 紧随 其后 而来 -显示 出 了 极 强 的 威力 - -==> data/local/lm/lm.arpa <== - -\data\ -ngram 1=129356 -ngram 2=504661 -ngram 3=123455 - 
-\1-grams: --1.531278 --3.828829 -0.1600094 --6.157292 - -==> data/local/lm/text <== -BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购 -BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉 -BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 -BAC009S0002W0125 各地 政府 便 纷纷 跟进 -BAC009S0002W0126 仅 一 个 多 月 的 时间 里 -BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 -BAC009S0002W0128 四十六 个 限 购 城市 当中 -BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购 -BAC009S0002W0130 财政 金融 政策 紧随 其后 而来 -BAC009S0002W0131 显示 出 了 极 强 的 威力 - -==> data/local/lm/text.no_oov <== - 而 对 楼市 成交 抑制 作用 最 大 的 限 购 - 也 成为 地方 政府 的 眼中 钉 - 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 - 各地 政府 便 纷纷 跟进 - 仅 一 个 多 月 的 时间 里 - 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 - 四十六 个 限 购 城市 当中 - 四十一 个 已 正式 取消 或 变相 放松 了 限 购 - 财政 ���融 政策 紧随 其后 而来 - 显示 出 了 极 强 的 威力 - -==> data/local/lm/train <== -汉莎 不 得 不 通过 这样 的 方式 寻求 新 的 发展 点 -并 计划 朝云 计算 方面 发展 -汉莎 的 基础 设施 部门 拥有 一千四百 名 员工 -媒体 就 曾 披露 这笔 交易 -虽然 双方 已经 正式 签署 了 外包 协议 -但是 这笔 交易 还 需要 得到 反 垄断 部门 的 批准 -陈 黎明 一九八九 年 获得 美国 康乃尔 大学 硕士 学位 -并 于 二零零三 年 顺利 完成 美国 哈佛 商学 院 高级 管理 课程 -曾 在 多家 国际 公司 任职 -拥有 业务 开发 商务 及 企业 治理 - -==> data/local/lm/unigram.counts <== - 57487 的 - 13099 在 - 11862 一 - 11397 了 - 10998 不 - 9913 是 - 7952 有 - 6250 和 - 6152 个 - 5422 将 - -==> data/local/lm/word.counts <== - 57486 的 - 13098 在 - 11861 一 - 11396 了 - 10997 不 - 9912 是 - 7951 有 - 6249 和 - 6151 个 - 5421 将 - -==> data/local/lm/wordlist <== -的 -在 -一 -了 -不 -是 -有 -和 -个 -将 -``` - -## Output - -``` -fstaddselfloops 'echo 4234 |' 'echo 123660 |' -Lexicon and Token FSTs compiling succeeded -arpa2fst --read-symbol-table=data/lang_test/words.txt --keep-symbols=true - -LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:94) Reading \data\ section. -LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \1-grams: section. -LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \2-grams: section. -LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \3-grams: section. 
-Checking how stochastic G is (the first of these numbers should be small): -fstisstochastic data/lang_test/G.fst -0 -1.14386 -fsttablecompose data/lang_test/L.fst data/lang_test/G.fst -fstminimizeencoded -fstdeterminizestar --use-log=true -fsttablecompose data/lang_test/T.fst data/lang_test/LG.fst -Composing decoding graph TLG.fst succeeded -Aishell build TLG done. -``` - -``` -data/ -├── lang_test -│ ├── G.fst -│ ├── L.fst -│ ├── LG.fst -│ ├── T.fst -│ ├── TLG.fst -│ ├── tokens.txt -│ ├── units.txt -│ └── words.txt -└── local - ├── lang - │ ├── L.fst - │ ├── T.fst - │ ├── tokens.txt - │ ├── units.txt - │ └── words.txt - └── tmp - ├── disambig.list - ├── lexiconp_disambig.txt - ├── lexiconp.txt - └── units.list -``` diff --git a/speechx/examples/wfst/path.sh b/speechx/examples/wfst/path.sh deleted file mode 100644 index a07c1297d..000000000 --- a/speechx/examples/wfst/path.sh +++ /dev/null @@ -1,19 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -MAIN_ROOT=`realpath $PWD/../../../` -SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` - -export LC_AL=C - -# srilm -export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs -export SRILM=${MAIN_ROOT}/tools/srilm -export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64 - -# Kaldi -export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi -[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh -export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH -[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!" -[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh diff --git a/speechx/examples/wfst/run.sh b/speechx/examples/wfst/run.sh deleted file mode 100755 index 1354646af..000000000 --- a/speechx/examples/wfst/run.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -eo pipefail - -. 
path.sh - -stage=-1 -stop_stage=100 - -. utils/parse_options.sh - -if ! which fstprint ; then - pushd $MAIN_ROOT/tools - make kaldi.done - popd -fi - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # build T & L - # utils/fst/compile_lexicon_token_fst.sh - utils/fst/compile_lexicon_token_fst.sh \ - data/local/dict data/local/tmp data/local/lang - - # build G & LG & TLG - # utils/fst/make_tlg.sh - utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; -fi - -echo "build TLG done." -exit 0 diff --git a/speechx/examples/wfst/utils b/speechx/examples/wfst/utils deleted file mode 120000 index 256f914ab..000000000 --- a/speechx/examples/wfst/utils +++ /dev/null @@ -1 +0,0 @@ -../../../utils/ \ No newline at end of file diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index b4da095d8..c8e21d486 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -34,6 +34,12 @@ add_subdirectory(decoder) include_directories( ${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/websocket +${CMAKE_CURRENT_SOURCE_DIR}/protocol ) -add_subdirectory(websocket) +add_subdirectory(protocol) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/codelab +) +add_subdirectory(codelab) diff --git a/speechx/examples/dev/CMakeLists.txt b/speechx/speechx/codelab/CMakeLists.txt similarity index 76% rename from speechx/examples/dev/CMakeLists.txt rename to speechx/speechx/codelab/CMakeLists.txt index c8445fb82..950432637 100644 --- a/speechx/examples/dev/CMakeLists.txt +++ b/speechx/speechx/codelab/CMakeLists.txt @@ -1,3 +1,4 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_subdirectory(glog) +add_subdirectory(nnet) diff --git a/speechx/speechx/codelab/README.md b/speechx/speechx/codelab/README.md new file mode 100644 index 000000000..077c4cef2 --- /dev/null +++ b/speechx/speechx/codelab/README.md @@ -0,0 +1,6 @@ + +## For Developer + +> Reminder: Only for developer. 
+ +* codelab - for speechx developer, using for test. diff --git a/speechx/speechx/codelab/glog/CMakeLists.txt b/speechx/speechx/codelab/glog/CMakeLists.txt new file mode 100644 index 000000000..08a98641f --- /dev/null +++ b/speechx/speechx/codelab/glog/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc) +target_link_libraries(glog_main glog) + + +add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc) +target_link_libraries(glog_logtostderr_main glog) diff --git a/speechx/examples/dev/glog/README.md b/speechx/speechx/codelab/glog/README.md similarity index 92% rename from speechx/examples/dev/glog/README.md rename to speechx/speechx/codelab/glog/README.md index 996e192e9..3282c920d 100644 --- a/speechx/examples/dev/glog/README.md +++ b/speechx/speechx/codelab/glog/README.md @@ -23,3 +23,16 @@ You can also modify flag values in your program by modifying global variables `F FLAGS_log_dir = "/some/log/directory"; LOG(INFO) << "the same file"; ``` + +* this is the test script: +``` +# run +glog_test + +echo "------" +export FLAGS_logtostderr=1 +glog_test + +echo "------" +glog_logtostderr_test +``` diff --git a/speechx/examples/dev/glog/glog_logtostderr_test.cc b/speechx/speechx/codelab/glog/glog_logtostderr_main.cc similarity index 100% rename from speechx/examples/dev/glog/glog_logtostderr_test.cc rename to speechx/speechx/codelab/glog/glog_logtostderr_main.cc diff --git a/speechx/examples/dev/glog/glog_test.cc b/speechx/speechx/codelab/glog/glog_main.cc similarity index 100% rename from speechx/examples/dev/glog/glog_test.cc rename to speechx/speechx/codelab/glog/glog_main.cc diff --git a/speechx/examples/ds2_ol/nnet/CMakeLists.txt b/speechx/speechx/codelab/nnet/CMakeLists.txt similarity index 87% rename from speechx/examples/ds2_ol/nnet/CMakeLists.txt rename to speechx/speechx/codelab/nnet/CMakeLists.txt index 6745a51ae..dcad8a9c6 
100644 --- a/speechx/examples/ds2_ol/nnet/CMakeLists.txt +++ b/speechx/speechx/codelab/nnet/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) -set(bin_name ds2-model-ol-test) +set(bin_name ds2_model_test_main) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS}) \ No newline at end of file +target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS}) diff --git a/speechx/examples/ds2_ol/nnet/ds2-model-ol-test.cc b/speechx/speechx/codelab/nnet/ds2_model_test_main.cc similarity index 100% rename from speechx/examples/ds2_ol/nnet/ds2-model-ol-test.cc rename to speechx/speechx/codelab/nnet/ds2_model_test_main.cc diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 06bf4020f..1df935112 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -10,3 +10,16 @@ add_library(decoder STATIC recognizer.cc ) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) + +set(BINS + ctc_prefix_beam_search_decoder_main + nnet_logprob_decoder_main + recognizer_main + tlg_decoder_main +) + +foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +endforeach() diff --git a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc similarity index 94% rename from speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc rename to speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index eaec41b71..7cfee06c9 100644 --- 
a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -30,10 +30,10 @@ DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(lm_path, "", "language model"); DEFINE_int32(receptive_field_length, 7, - "receptive field of two CNN(kernel=5) downsampling module."); + "receptive field of two CNN(kernel=3) downsampling module."); DEFINE_int32(downsampling_rate, 4, - "two CNN(kernel=5) module downsampling rate."); + "two CNN(kernel=3) module downsampling rate."); DEFINE_string( model_input_names, "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", @@ -45,6 +45,7 @@ DEFINE_string(model_cache_names, "chunk_state_h_box,chunk_state_c_box", "model cache names"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); +DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); using kaldi::BaseFloat; using kaldi::Matrix; @@ -90,8 +91,9 @@ int main(int argc, char* argv[]) { std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); - int32 chunk_size = FLAGS_receptive_field_length; - int32 chunk_stride = FLAGS_downsampling_rate; + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 02e643165..712d27dd4 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -47,6 +47,26 @@ void TLGDecoder::Reset() { return; } +std::string TLGDecoder::GetPartialResult() { + if (frame_decoded_size_ == 0) { + // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call + // 
BestPathEnd if no frames were decoded.") + return std::string(""); + } + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, false); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} + std::string TLGDecoder::GetFinalBestPath() { if (frame_decoded_size_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index 361c44af5..1ac46ac64 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -38,6 +38,7 @@ class TLGDecoder { std::string GetBestPath(); std::vector> GetNBestPath(); std::string GetFinalBestPath(); + std::string GetPartialResult(); int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); diff --git a/speechx/examples/ds2_ol/decoder/nnet-logprob-decoder-test.cc b/speechx/speechx/decoder/nnet_logprob_decoder_main.cc similarity index 100% rename from speechx/examples/ds2_ol/decoder/nnet-logprob-decoder-test.cc rename to speechx/speechx/decoder/nnet_logprob_decoder_main.cc diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h index b2bf1890a..d6ee27058 100644 --- a/speechx/speechx/decoder/param.h +++ b/speechx/speechx/decoder/param.h @@ -25,15 +25,14 @@ DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); // feature, or fbank"); DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_string(cmvn_file, "", "read cmvn"); -DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); // feature sliding window DEFINE_int32(receptive_field_length, 7, - "receptive field of two CNN(kernel=5) downsampling module."); + "receptive field of two 
CNN(kernel=3) downsampling module."); DEFINE_int32(downsampling_rate, 4, - "two CNN(kernel=5) module downsampling rate."); - + "two CNN(kernel=3) module downsampling rate."); +DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); // nnet DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); @@ -62,7 +61,6 @@ namespace ppspeech { FeaturePipelineOptions InitFeaturePipelineOptions() { FeaturePipelineOptions opts; opts.cmvn_file = FLAGS_cmvn_file; - opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk; kaldi::FrameExtractionOptions frame_opts; frame_opts.dither = 0.0; frame_opts.frame_shift_ms = 10; @@ -71,8 +69,8 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { opts.to_float32 = false; frame_opts.window_type = "povey"; frame_opts.frame_length_ms = 25; - opts.fbank_opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - opts.fbank_opts.fbank_opts.frame_opts = frame_opts; + opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + opts.fbank_opts.frame_opts = frame_opts; } else { opts.to_float32 = true; frame_opts.remove_dc_offset = false; @@ -81,8 +79,10 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { frame_opts.preemph_coeff = 0.0; opts.linear_spectrogram_opts.frame_opts = frame_opts; } - opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length; - opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate; + opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate; + opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length; + opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk; + return opts; } @@ -115,4 +115,4 @@ RecognizerResource InitRecognizerResoure() { resource.tlg_opts = InitDecoderOptions(); return resource; } -} \ No newline at end of file +} diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc index 2c90ada99..44c3911c9 100644 --- 
a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/decoder/recognizer.cc @@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() { return decoder_->GetFinalBestPath(); } +std::string Recognizer::GetPartialResult() { + return decoder_->GetPartialResult(); +} + void Recognizer::SetFinished() { feature_pipeline_->SetFinished(); input_finished_ = true; diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index 9a7e7d11e..35e1e1676 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -43,6 +43,7 @@ class Recognizer { void Accept(const kaldi::Vector& waves); void Decode(); std::string GetFinalResult(); + std::string GetPartialResult(); void SetFinished(); bool IsFinished(); void Reset(); diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/speechx/decoder/recognizer_main.cc similarity index 98% rename from speechx/examples/ds2_ol/decoder/recognizer_test_main.cc rename to speechx/speechx/decoder/recognizer_main.cc index 7aef73f74..232513539 100644 --- a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc +++ b/speechx/speechx/decoder/recognizer_main.cc @@ -19,6 +19,7 @@ DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); int main(int argc, char* argv[]) { @@ -96,4 +97,4 @@ int main(int argc, char* argv[]) { KALDI_LOG << " cost:" << elapsed << " s"; KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s"; KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration; -} \ No newline at end of file +} diff --git a/speechx/examples/ds2_ol/decoder/wfst-decoder-ol.cc b/speechx/speechx/decoder/tlg_decoder_main.cc similarity index 94% rename from speechx/examples/ds2_ol/decoder/wfst-decoder-ol.cc rename to speechx/speechx/decoder/tlg_decoder_main.cc index 
fefc16d2c..b175ed135 100644 --- a/speechx/examples/ds2_ol/decoder/wfst-decoder-ol.cc +++ b/speechx/speechx/decoder/tlg_decoder_main.cc @@ -28,15 +28,15 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(graph_path, "TLG", "decoder graph"); - DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_int32(max_active, 7500, "decoder graph"); +DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(receptive_field_length, 7, - "receptive field of two CNN(kernel=5) downsampling module."); + "receptive field of two CNN(kernel=3) downsampling module."); DEFINE_int32(downsampling_rate, 4, - "two CNN(kernel=5) module downsampling rate."); + "two CNN(kernel=3) module downsampling rate."); DEFINE_string( model_input_names, "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", @@ -93,8 +93,9 @@ int main(int argc, char* argv[]) { std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - int32 chunk_size = FLAGS_receptive_field_length; - int32 chunk_stride = FLAGS_downsampling_rate; + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/frontend/audio/CMakeLists.txt b/speechx/speechx/frontend/audio/CMakeLists.txt index 745832fe7..8ae63256a 100644 --- a/speechx/speechx/frontend/audio/CMakeLists.txt +++ b/speechx/speechx/frontend/audio/CMakeLists.txt @@ -8,6 +8,24 @@ add_library(frontend STATIC feature_cache.cc feature_pipeline.cc fbank.cc + assembler.cc ) - target_link_libraries(frontend PUBLIC kaldi-matrix 
kaldi-feat-common kaldi-fbank) + + + +set(bin_name cmvn_json2kaldi_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog) + +set(BINS + compute_linear_spectrogram_main + compute_fbank_main +) + +foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog) +endforeach() diff --git a/speechx/speechx/frontend/audio/assembler.cc b/speechx/speechx/frontend/audio/assembler.cc new file mode 100644 index 000000000..37eeec80f --- /dev/null +++ b/speechx/speechx/frontend/audio/assembler.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "frontend/audio/assembler.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::VectorBase; +using kaldi::BaseFloat; +using std::unique_ptr; + +Assembler::Assembler(AssemblerOptions opts, + unique_ptr base_extractor) { + frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk; + frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + + opts.receptive_filed_length; + receptive_filed_length_ = opts.receptive_filed_length; + base_extractor_ = std::move(base_extractor); + dim_ = base_extractor_->Dim(); +} + +void Assembler::Accept(const kaldi::VectorBase& inputs) { + // read inputs + base_extractor_->Accept(inputs); +} + +// pop feature chunk +bool Assembler::Read(kaldi::Vector* feats) { + feats->Resize(dim_ * frame_chunk_size_); + bool result = Compute(feats); + return result; +} + +// read all data from base_feature_extractor_ into cache_ +bool Assembler::Compute(Vector* feats) { + // compute and feed + bool result = false; + while (feature_cache_.size() < frame_chunk_size_) { + Vector feature; + result = base_extractor_->Read(&feature); + if (result == false || feature.Dim() == 0) { + if (IsFinished() == false) return false; + break; + } + feature_cache_.push(feature); + } + + if (feature_cache_.size() < receptive_filed_length_) { + return false; + } + + while (feature_cache_.size() < frame_chunk_size_) { + Vector feature(dim_, kaldi::kSetZero); + feature_cache_.push(feature); + } + + int32 counter = 0; + int32 cache_size = frame_chunk_size_ - frame_chunk_stride_; + int32 elem_dim = base_extractor_->Dim(); + while (counter < frame_chunk_size_) { + Vector& val = feature_cache_.front(); + int32 start = counter * elem_dim; + feats->Range(start, elem_dim).CopyFromVec(val); + if (frame_chunk_size_ - counter <= cache_size) { + feature_cache_.push(val); + } + feature_cache_.pop(); + counter++; + } + + return result; +} + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/assembler.h 
b/speechx/speechx/frontend/audio/assembler.h new file mode 100644 index 000000000..258e61f2b --- /dev/null +++ b/speechx/speechx/frontend/audio/assembler.h @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "frontend/audio/frontend_itf.h" + +namespace ppspeech { + +struct AssemblerOptions { + // refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py + // the nnet batch forward + int32 receptive_filed_length; + int32 subsampling_rate; + int32 nnet_decoder_chunk; + + AssemblerOptions() + : receptive_filed_length(1), + subsampling_rate(1), + nnet_decoder_chunk(1) {} +}; + +class Assembler : public FrontendInterface { + public: + explicit Assembler( + AssemblerOptions opts, + std::unique_ptr base_extractor = NULL); + + // Feed feats or waves + virtual void Accept(const kaldi::VectorBase& inputs); + + // feats size = num_frames * feat_dim + virtual bool Read(kaldi::Vector* feats); + + // feat dim + virtual size_t Dim() const { return dim_; } + + virtual void SetFinished() { base_extractor_->SetFinished(); } + + virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + + virtual void Reset() { base_extractor_->Reset(); } + + private: + bool Compute(kaldi::Vector* feats); + + int32 dim_; + int32 frame_chunk_size_; // window + int32 frame_chunk_stride_; // stride + 
int32 receptive_filed_length_; + std::queue> feature_cache_; + std::unique_ptr base_extractor_; + DISALLOW_COPY_AND_ASSIGN(Assembler); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h index 4ebcd9474..fc07d4bab 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/speechx/speechx/frontend/audio/audio_cache.h @@ -30,8 +30,9 @@ class AudioCache : public FrontendInterface { virtual bool Read(kaldi::Vector* waves); - // the audio dim is 1, one sample - virtual size_t Dim() const { return 1; } + // the audio dim is 1, one sample, which is useless, + // so we return size_(cache samples) instead. + virtual size_t Dim() const { return size_; } virtual void SetFinished() { std::lock_guard lock(mutex_); diff --git a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc b/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc similarity index 100% rename from speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc rename to speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc diff --git a/speechx/examples/ds2_ol/feat/compute_fbank_main.cc b/speechx/speechx/frontend/audio/compute_fbank_main.cc similarity index 91% rename from speechx/examples/ds2_ol/feat/compute_fbank_main.cc rename to speechx/speechx/frontend/audio/compute_fbank_main.cc index 67683eebf..f7a42315f 100644 --- a/speechx/examples/ds2_ol/feat/compute_fbank_main.cc +++ b/speechx/speechx/frontend/audio/compute_fbank_main.cc @@ -49,12 +49,11 @@ int main(int argc, char* argv[]) { std::unique_ptr data_source( new ppspeech::AudioCache(3600 * 1600, false)); - ppspeech::FbankOptions opt; - opt.fbank_opts.frame_opts.frame_length_ms = 25; - opt.fbank_opts.frame_opts.frame_shift_ms = 10; - opt.streaming_chunk = FLAGS_streaming_chunk; - opt.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - opt.fbank_opts.frame_opts.dither = 0.0; + kaldi::FbankOptions opt; + opt.frame_opts.frame_length_ms = 25; + opt.frame_opts.frame_shift_ms = 10; + 
opt.mel_opts.num_bins = FLAGS_num_bins; + opt.frame_opts.dither = 0.0; std::unique_ptr fbank( new ppspeech::Fbank(opt, std::move(data_source))); @@ -64,10 +63,6 @@ int main(int argc, char* argv[]) { ppspeech::FeatureCacheOptions feat_cache_opts; // the feature cache output feature chunk by chunk. - // frame_chunk_size : num frame of a chunk. - // frame_chunk_stride: chunk sliding window stride. - feat_cache_opts.frame_chunk_stride = 1; - feat_cache_opts.frame_chunk_size = 1; ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc b/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc similarity index 95% rename from speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc rename to speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc index bbf0e6908..162c3529d 100644 --- a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc +++ b/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// todo refactor, repalce with gtest - #include "base/flags.h" #include "base/log.h" #include "kaldi/feat/wave-reader.h" @@ -51,7 +49,6 @@ int main(int argc, char* argv[]) { ppspeech::LinearSpectrogramOptions opt; opt.frame_opts.frame_length_ms = 20; opt.frame_opts.frame_shift_ms = 10; - opt.streaming_chunk = FLAGS_streaming_chunk; opt.frame_opts.dither = 0.0; opt.frame_opts.remove_dc_offset = false; opt.frame_opts.window_type = "hanning"; @@ -68,10 +65,6 @@ int main(int argc, char* argv[]) { ppspeech::FeatureCacheOptions feat_cache_opts; // the feature cache output feature chunk by chunk. - // frame_chunk_size : num frame of a chunk. - // frame_chunk_stride: chunk sliding window stride. 
- feat_cache_opts.frame_chunk_stride = 1; - feat_cache_opts.frame_chunk_size = 1; ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/speechx/frontend/audio/fbank.cc b/speechx/speechx/frontend/audio/fbank.cc index fea9032ac..059abbbd1 100644 --- a/speechx/speechx/frontend/audio/fbank.cc +++ b/speechx/speechx/frontend/audio/fbank.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "frontend/audio/fbank.h" #include "kaldi/base/kaldi-math.h" #include "kaldi/feat/feature-common.h" @@ -29,95 +28,33 @@ using kaldi::VectorBase; using kaldi::Matrix; using std::vector; -// todo refactor later:(SmileGoat) - -Fbank::Fbank(const FbankOptions& opts, - std::unique_ptr base_extractor) - : opts_(opts), - computer_(opts.fbank_opts), - window_function_(opts.fbank_opts.frame_opts) { - base_extractor_ = std::move(base_extractor); - chunk_sample_size_ = static_cast( - opts.streaming_chunk * opts.fbank_opts.frame_opts.samp_freq); -} +FbankComputer::FbankComputer(const Options& opts) + : opts_(opts), computer_(opts) {} -void Fbank::Accept(const VectorBase& inputs) { - base_extractor_->Accept(inputs); +int32 FbankComputer::Dim() const { + return opts_.mel_opts.num_bins + (opts_.use_energy ? 
1 : 0); } -bool Fbank::Read(Vector* feats) { - Vector wav(chunk_sample_size_); - bool flag = base_extractor_->Read(&wav); - if (flag == false || wav.Dim() == 0) return false; - - // append remaned waves - int32 wav_len = wav.Dim(); - int32 left_len = remained_wav_.Dim(); - Vector waves(left_len + wav_len); - waves.Range(0, left_len).CopyFromVec(remained_wav_); - waves.Range(left_len, wav_len).CopyFromVec(wav); - - // compute speech feature - Compute(waves, feats); - - // cache remaned waves - kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); - int32 num_frames = kaldi::NumFrames(waves.Dim(), frame_opts); - int32 frame_shift = frame_opts.WindowShift(); - int32 left_samples = waves.Dim() - frame_shift * num_frames; - remained_wav_.Resize(left_samples); - remained_wav_.CopyFromVec( - waves.Range(frame_shift * num_frames, left_samples)); - return true; +bool FbankComputer::NeedRawLogEnergy() { + return opts_.use_energy && opts_.raw_energy; } -// Compute spectrogram feat -bool Fbank::Compute(const Vector& waves, Vector* feats) { - const kaldi::FrameExtractionOptions& frame_opts = - computer_.GetFrameOptions(); - int32 num_samples = waves.Dim(); - int32 frame_length = frame_opts.WindowSize(); - int32 sample_rate = frame_opts.samp_freq; - if (num_samples < frame_length) { - return true; - } - - int32 num_frames = kaldi::NumFrames(num_samples, frame_opts); - feats->Resize(num_frames * Dim()); - - Vector window; - bool need_raw_log_energy = computer_.NeedRawLogEnergy(); - for (int32 frame = 0; frame < num_frames; frame++) { - BaseFloat raw_log_energy = 0.0; - kaldi::ExtractWindow(0, - waves, - frame, - frame_opts, - window_function_, - &window, - need_raw_log_energy ? &raw_log_energy : NULL); - - - Vector this_feature(computer_.Dim(), kaldi::kUndefined); - // note: this online feature-extraction code does not support VTLN. 
- RealFft(&window, true); - kaldi::ComputePowerSpectrum(&window); - const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0)); - SubVector power_spectrum(window, 0, window.Dim() / 2 + 1); - if (!opts_.fbank_opts.use_power) { - power_spectrum.ApplyPow(0.5); - } - int32 mel_offset = - ((opts_.fbank_opts.use_energy && !opts_.fbank_opts.htk_compat) ? 1 - : 0); - SubVector mel_energies( - this_feature, mel_offset, opts_.fbank_opts.mel_opts.num_bins); - mel_bank.Compute(power_spectrum, &mel_energies); - mel_energies.ApplyFloor(1e-07); - mel_energies.ApplyLog(); - SubVector output_row(feats->Data() + frame * Dim(), Dim()); - output_row.CopyFromVec(this_feature); +// Compute feat +bool FbankComputer::Compute(Vector* window, + Vector* feat) { + RealFft(window, true); + kaldi::ComputePowerSpectrum(window); + const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0)); + SubVector power_spectrum(*window, 0, window->Dim() / 2 + 1); + if (!opts_.use_power) { + power_spectrum.ApplyPow(0.5); } + int32 mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 
1 : 0); + SubVector mel_energies( + *feat, mel_offset, opts_.mel_opts.num_bins); + mel_bank.Compute(power_spectrum, &mel_energies); + mel_energies.ApplyFloor(1e-07); + mel_energies.ApplyLog(); return true; } diff --git a/speechx/speechx/frontend/audio/fbank.h b/speechx/speechx/frontend/audio/fbank.h index 66957dc6d..a1e654138 100644 --- a/speechx/speechx/frontend/audio/fbank.h +++ b/speechx/speechx/frontend/audio/fbank.h @@ -15,6 +15,7 @@ #pragma once #include "base/common.h" +#include "frontend/audio/feature_common.h" #include "frontend/audio/frontend_itf.h" #include "kaldi/feat/feature-fbank.h" #include "kaldi/feat/feature-mfcc.h" @@ -22,56 +23,28 @@ namespace ppspeech { -struct FbankOptions { - kaldi::FbankOptions fbank_opts; - kaldi::BaseFloat streaming_chunk; // second - - FbankOptions() : streaming_chunk(0.1), fbank_opts() {} - - void Register(kaldi::OptionsItf* opts) { - opts->Register("streaming-chunk", - &streaming_chunk, - "streaming chunk size, default: 0.1 sec"); - fbank_opts.Register(opts); - } -}; - - -class Fbank : public FrontendInterface { +class FbankComputer { public: - explicit Fbank(const FbankOptions& opts, - std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& inputs); - virtual bool Read(kaldi::Vector* feats); + typedef kaldi::FbankOptions Options; + explicit FbankComputer(const Options& opts); - // the dim_ is the dim of single frame feature - virtual size_t Dim() const { return computer_.Dim(); } - - virtual void SetFinished() { base_extractor_->SetFinished(); } + kaldi::FrameExtractionOptions& GetFrameOptions() { + return opts_.frame_opts; + } - virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + bool Compute(kaldi::Vector* window, + kaldi::Vector* feat); + int32 Dim() const; - virtual void Reset() { - base_extractor_->Reset(); - remained_wav_.Resize(0); - } + bool NeedRawLogEnergy(); private: - bool Compute(const kaldi::Vector& waves, - kaldi::Vector* feats); + Options opts_; - 
FbankOptions opts_; - std::unique_ptr base_extractor_; - - kaldi::FeatureWindowFunction window_function_; kaldi::FbankComputer computer_; - // features_ is the Mfcc or Plp or Fbank features that we have already - // computed. - kaldi::Vector features_; - kaldi::Vector remained_wav_; - kaldi::int32 chunk_sample_size_; - - DISALLOW_COPY_AND_ASSIGN(Fbank); + DISALLOW_COPY_AND_ASSIGN(FbankComputer); }; +typedef StreamingFeatureTpl Fbank; + } // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc index 05283bb7e..509a98c3b 100644 --- a/speechx/speechx/frontend/audio/feature_cache.cc +++ b/speechx/speechx/frontend/audio/feature_cache.cc @@ -26,8 +26,6 @@ using std::unique_ptr; FeatureCache::FeatureCache(FeatureCacheOptions opts, unique_ptr base_extractor) { max_size_ = opts.max_size; - frame_chunk_stride_ = opts.frame_chunk_stride; - frame_chunk_size_ = opts.frame_chunk_size; timeout_ = opts.timeout; // ms base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); @@ -74,24 +72,11 @@ bool FeatureCache::Compute() { bool result = base_extractor_->Read(&feature); if (result == false || feature.Dim() == 0) return false; - // join with remained - int32 joint_len = feature.Dim() + remained_feature_.Dim(); - Vector joint_feature(joint_len); - joint_feature.Range(0, remained_feature_.Dim()) - .CopyFromVec(remained_feature_); - joint_feature.Range(remained_feature_.Dim(), feature.Dim()) - .CopyFromVec(feature); - - // one by one, or stride with window - // controlled by frame_chunk_stride_ and frame_chunk_size_ - int32 num_chunk = - ((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1; + int32 num_chunk = feature.Dim() / dim_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { - int32 start = chunk_idx * frame_chunk_stride_ * dim_; - - Vector feature_chunk(frame_chunk_size_ * dim_); - SubVector tmp(joint_feature.Data() + start, - frame_chunk_size_ * dim_); + 
int32 start = chunk_idx * dim_; + Vector feature_chunk(dim_); + SubVector tmp(feature.Data() + start, dim_); feature_chunk.CopyFromVec(tmp); std::unique_lock lock(mutex_); @@ -104,13 +89,6 @@ bool FeatureCache::Compute() { cache_.push(feature_chunk); ready_read_condition_.notify_one(); } - - // cache remained feats - int32 remained_feature_len = - joint_len - num_chunk * frame_chunk_stride_ * dim_; - remained_feature_.Resize(remained_feature_len); - remained_feature_.CopyFromVec(joint_feature.Range( - frame_chunk_stride_ * num_chunk * dim_, remained_feature_len)); return result; } diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h index 0dc704bbf..b922de12c 100644 --- a/speechx/speechx/frontend/audio/feature_cache.h +++ b/speechx/speechx/frontend/audio/feature_cache.h @@ -21,14 +21,8 @@ namespace ppspeech { struct FeatureCacheOptions { int32 max_size; - int32 frame_chunk_size; - int32 frame_chunk_stride; int32 timeout; // ms - FeatureCacheOptions() - : max_size(kint16max), - frame_chunk_size(1), - frame_chunk_stride(1), - timeout(1) {} + FeatureCacheOptions() : max_size(kint16max), timeout(1) {} }; class FeatureCache : public FrontendInterface { @@ -80,7 +74,7 @@ class FeatureCache : public FrontendInterface { std::condition_variable ready_feed_condition_; std::condition_variable ready_read_condition_; - // DISALLOW_COPY_AND_ASSGIN(FeatureCache); + DISALLOW_COPY_AND_ASSIGN(FeatureCache); }; } // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_common.h b/speechx/speechx/frontend/audio/feature_common.h new file mode 100644 index 000000000..bad705c9f --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_common.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "frontend_itf.h" +#include "kaldi/feat/feature-window.h" + +namespace ppspeech { + +template +class StreamingFeatureTpl : public FrontendInterface { + public: + typedef typename F::Options Options; + StreamingFeatureTpl(const Options& opts, + std::unique_ptr base_extractor); + virtual void Accept(const kaldi::VectorBase& waves); + virtual bool Read(kaldi::Vector* feats); + + // the dim_ is the dim of single frame feature + virtual size_t Dim() const { return computer_.Dim(); } + + virtual void SetFinished() { base_extractor_->SetFinished(); } + + virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + + virtual void Reset() { + base_extractor_->Reset(); + remained_wav_.Resize(0); + } + + private: + bool Compute(const kaldi::Vector& waves, + kaldi::Vector* feats); + Options opts_; + std::unique_ptr base_extractor_; + kaldi::FeatureWindowFunction window_function_; + kaldi::Vector remained_wav_; + F computer_; +}; + +} // namespace ppspeech + +#include "frontend/audio/feature_common_inl.h" diff --git a/speechx/speechx/frontend/audio/feature_common_inl.h b/speechx/speechx/frontend/audio/feature_common_inl.h new file mode 100644 index 000000000..b86f79918 --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_common_inl.h @@ -0,0 +1,97 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +namespace ppspeech { + +template +StreamingFeatureTpl::StreamingFeatureTpl( + const Options& opts, std::unique_ptr base_extractor) + : opts_(opts), computer_(opts), window_function_(opts.frame_opts) { + base_extractor_ = std::move(base_extractor); +} + +template +void StreamingFeatureTpl::Accept( + const kaldi::VectorBase& waves) { + base_extractor_->Accept(waves); +} + +template +bool StreamingFeatureTpl::Read(kaldi::Vector* feats) { + kaldi::Vector wav(base_extractor_->Dim()); + bool flag = base_extractor_->Read(&wav); + if (flag == false || wav.Dim() == 0) return false; + + // append remaned waves + int32 wav_len = wav.Dim(); + int32 left_len = remained_wav_.Dim(); + kaldi::Vector waves(left_len + wav_len); + waves.Range(0, left_len).CopyFromVec(remained_wav_); + waves.Range(left_len, wav_len).CopyFromVec(wav); + + // compute speech feature + Compute(waves, feats); + + // cache remaned waves + kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); + int32 num_frames = kaldi::NumFrames(waves.Dim(), frame_opts); + int32 frame_shift = frame_opts.WindowShift(); + int32 left_samples = waves.Dim() - frame_shift * num_frames; + remained_wav_.Resize(left_samples); + remained_wav_.CopyFromVec( + waves.Range(frame_shift * num_frames, left_samples)); + return true; +} + +// Compute feat +template +bool StreamingFeatureTpl::Compute( + const kaldi::Vector& waves, + kaldi::Vector* feats) { + const kaldi::FrameExtractionOptions& frame_opts = + computer_.GetFrameOptions(); + int32 num_samples = waves.Dim(); + int32 frame_length = 
frame_opts.WindowSize(); + int32 sample_rate = frame_opts.samp_freq; + if (num_samples < frame_length) { + return true; + } + + int32 num_frames = kaldi::NumFrames(num_samples, frame_opts); + feats->Resize(num_frames * Dim()); + + kaldi::Vector window; + bool need_raw_log_energy = computer_.NeedRawLogEnergy(); + for (int32 frame = 0; frame < num_frames; frame++) { + kaldi::BaseFloat raw_log_energy = 0.0; + kaldi::ExtractWindow(0, + waves, + frame, + frame_opts, + window_function_, + &window, + need_raw_log_energy ? &raw_log_energy : NULL); + + kaldi::Vector this_feature(computer_.Dim(), + kaldi::kUndefined); + computer_.Compute(&window, &this_feature); + kaldi::SubVector output_row( + feats->Data() + frame * Dim(), Dim()); + output_row.CopyFromVec(this_feature); + } + return true; +} + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc index 087de0f0d..9cacff9f7 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.cc +++ b/speechx/speechx/frontend/audio/feature_pipeline.cc @@ -35,8 +35,11 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { unique_ptr cmvn( new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); - base_extractor_.reset( + unique_ptr cache( new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + + base_extractor_.reset( + new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); } -} // ppspeech \ No newline at end of file +} // ppspeech diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h index 6b9b4795e..48f95e3f3 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -16,6 +16,7 @@ #pragma once +#include "frontend/audio/assembler.h" #include "frontend/audio/audio_cache.h" #include "frontend/audio/data_cache.h" #include "frontend/audio/fbank.h" @@ -31,15 +32,18 @@ struct 
FeaturePipelineOptions { bool to_float32; // true, only for linear feature bool use_fbank; LinearSpectrogramOptions linear_spectrogram_opts; - FbankOptions fbank_opts; + kaldi::FbankOptions fbank_opts; FeatureCacheOptions feature_cache_opts; + AssemblerOptions assembler_opts; + FeaturePipelineOptions() : cmvn_file(""), to_float32(false), // true, only for linear feature use_fbank(true), linear_spectrogram_opts(), fbank_opts(), - feature_cache_opts() {} + feature_cache_opts(), + assembler_opts() {} }; class FeaturePipeline : public FrontendInterface { @@ -59,4 +63,4 @@ class FeaturePipeline : public FrontendInterface { private: std::unique_ptr base_extractor_; }; -} \ No newline at end of file +} diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc index 9ef5e7664..55c039787 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.cc +++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc @@ -28,81 +28,31 @@ using kaldi::VectorBase; using kaldi::Matrix; using std::vector; -LinearSpectrogram::LinearSpectrogram( - const LinearSpectrogramOptions& opts, - std::unique_ptr base_extractor) - : opts_(opts), feature_window_funtion_(opts.frame_opts) { - base_extractor_ = std::move(base_extractor); +LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts) + : opts_(opts) { + kaldi::FeatureWindowFunction feature_window_function(opts.frame_opts); int32 window_size = opts.frame_opts.WindowSize(); - int32 window_shift = opts.frame_opts.WindowShift(); + frame_length_ = window_size; dim_ = window_size / 2 + 1; - chunk_sample_size_ = - static_cast(opts.streaming_chunk * opts.frame_opts.samp_freq); - hanning_window_energy_ = kaldi::VecVec(feature_window_funtion_.window, - feature_window_funtion_.window); -} - -void LinearSpectrogram::Accept(const VectorBase& inputs) { - base_extractor_->Accept(inputs); -} - -bool LinearSpectrogram::Read(Vector* feats) { - Vector 
input_feats(chunk_sample_size_); - bool flag = base_extractor_->Read(&input_feats); - if (flag == false || input_feats.Dim() == 0) return false; - - int32 feat_len = input_feats.Dim(); - int32 left_len = remained_wav_.Dim(); - Vector waves(feat_len + left_len); - waves.Range(0, left_len).CopyFromVec(remained_wav_); - waves.Range(left_len, feat_len).CopyFromVec(input_feats); - Compute(waves, feats); - int32 frame_shift = opts_.frame_opts.WindowShift(); - int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); - int32 left_samples = waves.Dim() - frame_shift * num_frames; - remained_wav_.Resize(left_samples); - remained_wav_.CopyFromVec( - waves.Range(frame_shift * num_frames, left_samples)); - return true; + BaseFloat hanning_window_energy = kaldi::VecVec( + feature_window_function.window, feature_window_function.window); + int32 sample_rate = opts.frame_opts.samp_freq; + scale_ = 2.0 / (hanning_window_energy * sample_rate); } // Compute spectrogram feat -bool LinearSpectrogram::Compute(const Vector& waves, - Vector* feats) { - int32 num_samples = waves.Dim(); - int32 frame_length = opts_.frame_opts.WindowSize(); - int32 sample_rate = opts_.frame_opts.samp_freq; - BaseFloat scale = 2.0 / (hanning_window_energy_ * sample_rate); - - if (num_samples < frame_length) { - return true; - } - - int32 num_frames = kaldi::NumFrames(num_samples, opts_.frame_opts); - feats->Resize(num_frames * dim_); - Vector window; - - for (int frame_idx = 0; frame_idx < num_frames; ++frame_idx) { - kaldi::ExtractWindow(0, - waves, - frame_idx, - opts_.frame_opts, - feature_window_funtion_, - &window, - NULL); - - SubVector output_row(feats->Data() + frame_idx * dim_, dim_); - window.Resize(frame_length, kaldi::kCopyData); - RealFft(&window, true); - kaldi::ComputePowerSpectrum(&window); - SubVector power_spectrum(window, 0, dim_); - power_spectrum.Scale(scale); - power_spectrum(0) = power_spectrum(0) / 2; - power_spectrum(dim_ - 1) = power_spectrum(dim_ - 1) / 2; - 
power_spectrum.Add(1e-14); - power_spectrum.ApplyLog(); - output_row.CopyFromVec(power_spectrum); - } +bool LinearSpectrogramComputer::Compute(Vector* window, + Vector* feat) { + window->Resize(frame_length_, kaldi::kCopyData); + RealFft(window, true); + kaldi::ComputePowerSpectrum(window); + SubVector power_spectrum(*window, 0, dim_); + power_spectrum.Scale(scale_); + power_spectrum(0) = power_spectrum(0) / 2; + power_spectrum(dim_ - 1) = power_spectrum(dim_ - 1) / 2; + power_spectrum.Add(1e-14); + power_spectrum.ApplyLog(); + feat->CopyFromVec(power_spectrum); return true; } diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h index 2764b7cf4..de957c235 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.h +++ b/speechx/speechx/frontend/audio/linear_spectrogram.h @@ -16,6 +16,7 @@ #pragma once #include "base/common.h" +#include "frontend/audio/feature_common.h" #include "frontend/audio/frontend_itf.h" #include "kaldi/feat/feature-window.h" @@ -23,47 +24,34 @@ namespace ppspeech { struct LinearSpectrogramOptions { kaldi::FrameExtractionOptions frame_opts; - kaldi::BaseFloat streaming_chunk; // second - - LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {} - - void Register(kaldi::OptionsItf* opts) { - opts->Register("streaming-chunk", - &streaming_chunk, - "streaming chunk size, default: 0.1 sec"); - frame_opts.Register(opts); - } + LinearSpectrogramOptions() : frame_opts() {} }; -class LinearSpectrogram : public FrontendInterface { +class LinearSpectrogramComputer { public: - explicit LinearSpectrogram( - const LinearSpectrogramOptions& opts, - std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& inputs); - virtual bool Read(kaldi::Vector* feats); - // the dim_ is the dim of single frame feature - virtual size_t Dim() const { return dim_; } - virtual void SetFinished() { base_extractor_->SetFinished(); } - virtual bool IsFinished() const { return 
base_extractor_->IsFinished(); } - virtual void Reset() { - base_extractor_->Reset(); - remained_wav_.Resize(0); + typedef LinearSpectrogramOptions Options; + explicit LinearSpectrogramComputer(const Options& opts); + + kaldi::FrameExtractionOptions& GetFrameOptions() { + return opts_.frame_opts; } - private: - bool Compute(const kaldi::Vector& waves, - kaldi::Vector* feats); + bool Compute(kaldi::Vector* window, + kaldi::Vector* feat); - size_t dim_; - kaldi::FeatureWindowFunction feature_window_funtion_; - kaldi::BaseFloat hanning_window_energy_; - LinearSpectrogramOptions opts_; - std::unique_ptr base_extractor_; - kaldi::Vector remained_wav_; - int chunk_sample_size_; - DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); + int32 Dim() const { return dim_; } + + bool NeedRawLogEnergy() { return false; } + + private: + kaldi::BaseFloat scale_; + Options opts_; + int32 frame_length_; + int32 dim_; + DISALLOW_COPY_AND_ASSIGN(LinearSpectrogramComputer); }; +typedef StreamingFeatureTpl LinearSpectrogram; + } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index 6f7398cd1..ce6b43f63 100644 --- a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -7,3 +7,7 @@ add_subdirectory(matrix) add_subdirectory(lat) add_subdirectory(fstext) add_subdirectory(decoder) +add_subdirectory(lm) + +add_subdirectory(fstbin) +add_subdirectory(lmbin) \ No newline at end of file diff --git a/speechx/speechx/kaldi/fstbin/CMakeLists.txt b/speechx/speechx/kaldi/fstbin/CMakeLists.txt new file mode 100644 index 000000000..05d0501f3 --- /dev/null +++ b/speechx/speechx/kaldi/fstbin/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +set(BINS +fstaddselfloops +fstisstochastic +fstminimizeencoded +fstdeterminizestar +fsttablecompose +) + +foreach(binary IN LISTS BINS) + add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc) + 
target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl) +endforeach() diff --git a/speechx/tools/fstbin/fstaddselfloops.cc b/speechx/speechx/kaldi/fstbin/fstaddselfloops.cc similarity index 100% rename from speechx/tools/fstbin/fstaddselfloops.cc rename to speechx/speechx/kaldi/fstbin/fstaddselfloops.cc diff --git a/speechx/tools/fstbin/fstdeterminizestar.cc b/speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc similarity index 100% rename from speechx/tools/fstbin/fstdeterminizestar.cc rename to speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc diff --git a/speechx/tools/fstbin/fstisstochastic.cc b/speechx/speechx/kaldi/fstbin/fstisstochastic.cc similarity index 100% rename from speechx/tools/fstbin/fstisstochastic.cc rename to speechx/speechx/kaldi/fstbin/fstisstochastic.cc diff --git a/speechx/tools/fstbin/fstminimizeencoded.cc b/speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc similarity index 100% rename from speechx/tools/fstbin/fstminimizeencoded.cc rename to speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc diff --git a/speechx/tools/fstbin/fsttablecompose.cc b/speechx/speechx/kaldi/fstbin/fsttablecompose.cc similarity index 100% rename from speechx/tools/fstbin/fsttablecompose.cc rename to speechx/speechx/kaldi/fstbin/fsttablecompose.cc diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt index af91fd985..465d9dba7 100644 --- a/speechx/speechx/kaldi/fstext/CMakeLists.txt +++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt @@ -1,5 +1,5 @@ add_library(kaldi-fstext -kaldi-fst-io.cc + kaldi-fst-io.cc ) target_link_libraries(kaldi-fstext PUBLIC kaldi-util) diff --git a/speechx/speechx/kaldi/lm/CMakeLists.txt b/speechx/speechx/kaldi/lm/CMakeLists.txt new file mode 100644 index 000000000..75c1567e7 --- /dev/null +++ b/speechx/speechx/kaldi/lm/CMakeLists.txt @@ -0,0 +1,6 @@ + +add_library(kaldi-lm + 
arpa-file-parser.cc + arpa-lm-compiler.cc +) +target_link_libraries(kaldi-lm PUBLIC kaldi-util) \ No newline at end of file diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.cc b/speechx/speechx/kaldi/lm/arpa-file-parser.cc new file mode 100644 index 000000000..81b63ed13 --- /dev/null +++ b/speechx/speechx/kaldi/lm/arpa-file-parser.cc @@ -0,0 +1,281 @@ +// lm/arpa-file-parser.cc + +// Copyright 2014 Guoguo Chen +// Copyright 2016 Smart Action Company LLC (kkm) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "base/kaldi-error.h" +#include "base/kaldi-math.h" +#include "lm/arpa-file-parser.h" +#include "util/text-utils.h" + +namespace kaldi { + +ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options, + fst::SymbolTable* symbols) + : options_(options), symbols_(symbols), + line_number_(0), warning_count_(0) { +} + +ArpaFileParser::~ArpaFileParser() { +} + +void TrimTrailingWhitespace(std::string *str) { + str->erase(str->find_last_not_of(" \n\r\t") + 1); +} + +void ArpaFileParser::Read(std::istream &is) { + // Argument sanity checks. + if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 || + options_.bos_symbol == options_.eos_symbol) + KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and " + << "differ from each other. 
Given:" + << " BOS=" << options_.bos_symbol + << " EOS=" << options_.eos_symbol; + if (symbols_ != NULL && + options_.oov_handling == ArpaParseOptions::kReplaceWithUnk && + (options_.unk_symbol <= 0 || + options_.unk_symbol == options_.bos_symbol || + options_.unk_symbol == options_.eos_symbol)) + KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, " + << "UNK symbol is required, must not be epsilon, and " + << "differ from both BOS and EOS symbols. Given:" + << " UNK=" << options_.unk_symbol + << " BOS=" << options_.bos_symbol + << " EOS=" << options_.eos_symbol; + if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty()) + KALDI_ERR << "BOS symbol must exist in symbol table"; + if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty()) + KALDI_ERR << "EOS symbol must exist in symbol table"; + if (symbols_ != NULL && options_.unk_symbol > 0 && + symbols_->Find(options_.unk_symbol).empty()) + KALDI_ERR << "UNK symbol must exist in symbol table"; + + ngram_counts_.clear(); + line_number_ = 0; + warning_count_ = 0; + current_line_.clear(); + +#define PARSE_ERR KALDI_ERR << LineReference() << ": " + + // Give derived class an opportunity to prepare its state. + ReadStarted(); + + // Processes "\data\" section. + bool keyword_found = false; + while (++line_number_, getline(is, current_line_) && !is.eof()) { + if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) { + continue; + } + + TrimTrailingWhitespace(¤t_line_); + + // Continue skipping lines until the \data\ marker alone on a line is found. + if (!keyword_found) { + if (current_line_ == "\\data\\") { + KALDI_LOG << "Reading \\data\\ section."; + keyword_found = true; + } + continue; + } + + if (current_line_[0] == '\\') break; + + // Enters "\data\" section, and looks for patterns like "ngram 1=1000", + // which means there are 1000 unigrams. 
+ std::size_t equal_symbol_pos = current_line_.find("="); + if (equal_symbol_pos != std::string::npos) + // Guaranteed spaces around the "=". + current_line_.replace(equal_symbol_pos, 1, " = "); + std::vector col; + SplitStringToVector(current_line_, " \t", true, &col); + if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") { + int32 order, ngram_count = 0; + if (!ConvertStringToInteger(col[1], &order) || + !ConvertStringToInteger(col[3], &ngram_count)) { + PARSE_ERR << "cannot parse ngram count"; + } + if (ngram_counts_.size() <= order) { + ngram_counts_.resize(order); + } + ngram_counts_[order - 1] = ngram_count; + } else { + KALDI_WARN << LineReference() + << ": uninterpretable line in \\data\\ section"; + } + } + + if (ngram_counts_.size() == 0) + PARSE_ERR << "\\data\\ section missing or empty."; + + // Signal that grammar order and n-gram counts are known. + HeaderAvailable(); + + NGram ngram; + ngram.words.reserve(ngram_counts_.size()); + + // Processes "\N-grams:" section. + for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) { + // Skips n-grams with zero count. + if (ngram_counts_[cur_order - 1] == 0) + KALDI_WARN << "Zero ngram count in ngram order " << cur_order + << "(look for 'ngram " << cur_order << "=0' in the \\data\\ " + << " section). There is possibly a problem with the file."; + + // Must be looking at a \k-grams: directive at this point. 
+ std::ostringstream keyword; + keyword << "\\" << cur_order << "-grams:"; + if (current_line_ != keyword.str()) { + PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'"; + } + KALDI_LOG << "Reading " << current_line_ << " section."; + + int32 ngram_count = 0; + while (++line_number_, getline(is, current_line_) && !is.eof()) { + if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) { + continue; + } + if (current_line_[0] == '\\') { + TrimTrailingWhitespace(¤t_line_); + std::ostringstream next_keyword; + next_keyword << "\\" << cur_order + 1 << "-grams:"; + if ((current_line_ != next_keyword.str()) && + (current_line_ != "\\end\\")) { + if (ShouldWarn()) { + KALDI_WARN << "ignoring possible directive '" << current_line_ + << "' expecting '" << next_keyword.str() << "'"; + + if (warning_count_ > 0 && + warning_count_ > static_cast(options_.max_warnings)) { + KALDI_WARN << "Of " << warning_count_ << " parse warnings, " + << options_.max_warnings << " were reported. " + << "Run program with --max-arpa-warnings=-1 " + << "to see all warnings"; + } + } + } else { + break; + } + } + + std::vector col; + SplitStringToVector(current_line_, " \t", true, &col); + + if (col.size() < 1 + cur_order || + col.size() > 2 + cur_order || + (cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) { + PARSE_ERR << "Invalid n-gram data line"; + } + ++ngram_count; + + // Parse out n-gram logprob and, if present, backoff weight. + if (!ConvertStringToReal(col[0], &ngram.logprob)) { + PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'"; + } + ngram.backoff = 0.0; + if (col.size() > cur_order + 1) { + if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff)) + PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'"; + } + // Convert to natural log. 
+ ngram.logprob *= M_LN10; + ngram.backoff *= M_LN10; + + ngram.words.resize(cur_order); + bool skip_ngram = false; + for (int32 index = 0; !skip_ngram && index < cur_order; ++index) { + int32 word; + if (symbols_) { + // Symbol table provided, so symbol labels are expected. + if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) { + word = symbols_->AddSymbol(col[1 + index]); + } else { + word = symbols_->Find(col[1 + index]); + if (word == -1) { // fst::kNoSymbol + switch (options_.oov_handling) { + case ArpaParseOptions::kReplaceWithUnk: + word = options_.unk_symbol; + break; + case ArpaParseOptions::kSkipNGram: + if (ShouldWarn()) + KALDI_WARN << LineReference() << " skipped: word '" + << col[1 + index] << "' not in symbol table"; + skip_ngram = true; + break; + default: + PARSE_ERR << "word '" << col[1 + index] + << "' not in symbol table"; + } + } + } + } else { + // Symbols not provided, LM file should contain integers. + if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) { + PARSE_ERR << "invalid symbol '" << col[1 + index] << "'"; + } + } + // Whichever way we got it, an epsilon is invalid. + if (word == 0) { + PARSE_ERR << "epsilon symbol '" << col[1 + index] + << "' is illegal in ARPA LM"; + } + ngram.words[index] = word; + } + if (!skip_ngram) { + ConsumeNGram(ngram); + } + } + if (ngram_count > ngram_counts_[cur_order - 1]) { + PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1] + << " n-grams of order " << cur_order + << ", but we saw more already."; + } + } + + if (current_line_ != "\\end\\") { + PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\"; + } + + if (warning_count_ > 0 && + warning_count_ > static_cast(options_.max_warnings)) { + KALDI_WARN << "Of " << warning_count_ << " parse warnings, " + << options_.max_warnings << " were reported. 
Run program with " + << "--max-arpa-warnings=-1 to see all warnings"; + } + + current_line_.clear(); + ReadComplete(); + +#undef PARSE_ERR +} + +std::string ArpaFileParser::LineReference() const { + std::ostringstream ss; + ss << "line " << line_number_ << " [" << current_line_ << "]"; + return ss.str(); +} + +bool ArpaFileParser::ShouldWarn() { + return (warning_count_ != -1) && + (++warning_count_ <= static_cast(options_.max_warnings)); +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.h b/speechx/speechx/kaldi/lm/arpa-file-parser.h new file mode 100644 index 000000000..99ffba029 --- /dev/null +++ b/speechx/speechx/kaldi/lm/arpa-file-parser.h @@ -0,0 +1,146 @@ +// lm/arpa-file-parser.h + +// Copyright 2014 Guoguo Chen +// Copyright 2016 Smart Action Company LLC (kkm) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_LM_ARPA_FILE_PARSER_H_ +#define KALDI_LM_ARPA_FILE_PARSER_H_ + +#include + +#include +#include + +#include "base/kaldi-types.h" +#include "util/options-itf.h" + +namespace kaldi { + +/** + Options that control ArpaFileParser +*/ +struct ArpaParseOptions { + enum OovHandling { + kRaiseError, ///< Abort on OOV words + kAddToSymbols, ///< Add novel words to the symbol table. + kReplaceWithUnk, ///< Replace OOV words with . 
+ kSkipNGram ///< Skip n-gram with OOV word and continue. + }; + + ArpaParseOptions(): + bos_symbol(-1), eos_symbol(-1), unk_symbol(-1), + oov_handling(kRaiseError), max_warnings(30) { } + + void Register(OptionsItf *opts) { + // Registering only the max_warnings count, since other options are + // treated differently by client programs: some want integer symbols, + // while other are passed words in their command line. + opts->Register("max-arpa-warnings", &max_warnings, + "Maximum warnings to report on ARPA parsing, " + "0 to disable, -1 to show all"); + } + + int32 bos_symbol; ///< Symbol for , Required non-epsilon. + int32 eos_symbol; ///< Symbol for , Required non-epsilon. + int32 unk_symbol; ///< Symbol for , Required for kReplaceWithUnk. + OovHandling oov_handling; ///< How to handle OOV words in the file. + int32 max_warnings; ///< Maximum warnings to report, <0 unlimited. +}; + +/** + A parsed n-gram from ARPA LM file. +*/ +struct NGram { + NGram() : logprob(0.0), backoff(0.0) { } + std::vector words; ///< Symbols in left to right order. + float logprob; ///< Log-prob of the n-gram. + float backoff; ///< log-backoff weight of the n-gram. + ///< Defaults to zero if not specified. +}; + +/** + ArpaFileParser is an abstract base class for ARPA LM file conversion. + + See ConstArpaLmBuilder and ArpaLmCompiler for usage examples. +*/ +class ArpaFileParser { + public: + /// Constructs the parser with the given options and optional symbol table. + /// If symbol table is provided, then the file should contain text n-grams, + /// and the words are mapped to symbols through it. bos_symbol and + /// eos_symbol in the options structure must be valid symbols in the table, + /// and so must be unk_symbol if provided. The table is not owned by the + /// parser, but may be augmented, if oov_handling is set to kAddToSymbols. + /// If symbol table is a null pointer, the file should contain integer + /// symbol values, and oov_handling has no effect. 
bos_symbol and eos_symbol + /// must be valid symbols still. + ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols); + virtual ~ArpaFileParser(); + + /// Read ARPA LM file from a stream. + void Read(std::istream &is); + + /// Parser options. + const ArpaParseOptions& Options() const { return options_; } + + protected: + /// Override called before reading starts. This is the point to prepare + /// any state in the derived class. + virtual void ReadStarted() { } + + /// Override function called to signal that ARPA header with the expected + /// number of n-grams has been read, and ngram_counts() is now valid. + virtual void HeaderAvailable() { } + + /// Pure override that must be implemented to process current n-gram. The + /// n-grams are sent in the file order, which guarantees that all + /// (k-1)-grams are processed before the first k-gram is. + virtual void ConsumeNGram(const NGram&) = 0; + + /// Override function called after the last n-gram has been consumed. + virtual void ReadComplete() { } + + /// Read-only access to symbol table. Not owned, do not make public. + const fst::SymbolTable* Symbols() const { return symbols_; } + + /// Inside ConsumeNGram(), provides the current line number. + int32 LineNumber() const { return line_number_; } + + /// Inside ConsumeNGram(), returns a formatted reference to the line being + /// compiled, to print out as part of diagnostics. + std::string LineReference() const; + + /// Increments warning count, and returns true if a warning should be + /// printed or false if the count has exceeded the set maximum. + bool ShouldWarn(); + + /// N-gram counts. Valid from the point when HeaderAvailable() is called. + const std::vector& NgramCounts() const { return ngram_counts_; } + + private: + ArpaParseOptions options_; + fst::SymbolTable* symbols_; // the pointer is not owned here. 
+ int32 line_number_; + uint32 warning_count_; + std::string current_line_; + std::vector ngram_counts_; +}; + +} // namespace kaldi + +#endif // KALDI_LM_ARPA_FILE_PARSER_H_ diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc b/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc new file mode 100644 index 000000000..47bd20d47 --- /dev/null +++ b/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc @@ -0,0 +1,377 @@ +// lm/arpa-lm-compiler.cc + +// Copyright 2009-2011 Gilles Boulianne +// Copyright 2016 Smart Action LLC (kkm) +// Copyright 2017 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "base/kaldi-math.h" +#include "lm/arpa-lm-compiler.h" +#include "util/stl-utils.h" +#include "util/text-utils.h" +#include "fstext/remove-eps-local.h" + +namespace kaldi { + +class ArpaLmCompilerImplInterface { + public: + virtual ~ArpaLmCompilerImplInterface() { } + virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0; +}; + +namespace { + +typedef int32 StateId; +typedef int32 Symbol; + +// GeneralHistKey can represent state history in an arbitrarily large n +// n-gram model with symbol ids fitting int32. +class GeneralHistKey { + public: + // Construct key from being and end iterators. 
+ template + GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { } + // Construct empty history key. + GeneralHistKey() : vector_() { } + // Return tails of the key as a GeneralHistKey. The tails of an n-gram + // w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the + // key class does not need this operartion). + GeneralHistKey Tails() const { + return GeneralHistKey(vector_.begin() + 1, vector_.end()); + } + // Keys are equal if represent same state. + friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) { + return a.vector_ == b.vector_; + } + // Public typename HashType for hashing. + struct HashType : public std::unary_function { + size_t operator()(const GeneralHistKey& key) const { + return VectorHasher().operator()(key.vector_); + } + }; + + private: + std::vector vector_; +}; + +// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit +// machine word. allowing significant memory reduction and some runtime +// benefit over GeneralHistKey. Since 3 symbols are enough to track history +// in a 4-gram model, this optimized key is used for smaller models with up +// to 4-gram and symbol values up to 2^21-1. +// +// See GeneralHistKey for interface requirements of a key class. +class OptimizedHistKey { + public: + enum { + kShift = 21, // 21 * 3 = 63 bits for data. 
+ kMaxData = (1 << kShift) - 1 + }; + template + OptimizedHistKey(InputIt begin, InputIt end) : data_(0) { + for (uint32 shift = 0; begin != end; ++begin, shift += kShift) { + data_ |= static_cast(*begin) << shift; + } + } + OptimizedHistKey() : data_(0) { } + OptimizedHistKey Tails() const { + return OptimizedHistKey(data_ >> kShift); + } + friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) { + return a.data_ == b.data_; + } + struct HashType : public std::unary_function { + size_t operator()(const OptimizedHistKey& key) const { return key.data_; } + }; + + private: + explicit OptimizedHistKey(uint64 data) : data_(data) { } + uint64 data_; +}; + +} // namespace + +template +class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface { + public: + ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst, + Symbol sub_eps); + + virtual void ConsumeNGram(const NGram &ngram, bool is_highest); + + private: + StateId AddStateWithBackoff(HistKey key, float backoff); + void CreateBackoff(HistKey key, StateId state, float weight); + + ArpaLmCompiler *parent_; // Not owned. + fst::StdVectorFst* fst_; // Not owned. + Symbol bos_symbol_; + Symbol eos_symbol_; + Symbol sub_eps_; + + StateId eos_state_; + typedef unordered_map HistoryMap; + HistoryMap history_; +}; + +template +ArpaLmCompilerImpl::ArpaLmCompilerImpl( + ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps) + : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol), + eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) { + // The algorithm maintains state per history. The 0-gram is a special state + // for empty history. All unigrams (including BOS) backoff into this state. + StateId zerogram = fst_->AddState(); + history_[HistKey()] = zerogram; + + // Also, if is not treated as epsilon, create a common end state for + // all transitions accepting the , since they do not back off. 
This small + // optimization saves about 2% states in an average grammar. + if (sub_eps_ == 0) { + eos_state_ = fst_->AddState(); + fst_->SetFinal(eos_state_, 0); + } +} + +template +void ArpaLmCompilerImpl::ConsumeNGram(const NGram &ngram, + bool is_highest) { + // Generally, we do the following. Suppose we are adding an n-gram "A B + // C". Then find the node for "A B", add a new node for "A B C", and connect + // them with the arc accepting "C" with the specified weight. Also, add a + // backoff arc from the new "A B C" node to its backoff state "B C". + // + // Two notable exceptions are the highest order n-grams, and final n-grams. + // + // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM), + // the following optimization is performed. There is no point adding a node + // for "A B C" with a "C" arc from "A B", since there will be no other + // arcs ingoing to this node, and an epsilon backoff arc into the backoff + // model "B C", with the weight of \bar{1}. To save a node, create an arc + // accepting "C" directly from "A B" to "B C". This saves as many nodes + // as there are the highest order n-grams, which is typically about half + // the size of a large 3-gram model. + // + // Indeed, this does not apply to n-grams ending in EOS, since they do not + // back off. These are special, as they do not have a back-off state, and + // the node for "(..anything..) " is always final. These are handled + // in one of the two possible ways, If symbols and are being + // replaced by epsilons, neither node nor arc is created, and the logprob + // of the n-gram is applied to its source node as final weight. If and + // are preserved, then a special final node for is allocated and + // used as the destination of the "" acceptor arc. 
+ HistKey heads(ngram.words.begin(), ngram.words.end() - 1); + typename HistoryMap::iterator source_it = history_.find(heads); + if (source_it == history_.end()) { + // There was no "A B", therefore the probability of "A B C" is zero. + // Print a warning and discard current n-gram. + if (parent_->ShouldWarn()) + KALDI_WARN << parent_->LineReference() + << " skipped: no parent (n-1)-gram exists"; + return; + } + + StateId source = source_it->second; + StateId dest; + Symbol sym = ngram.words.back(); + float weight = -ngram.logprob; + if (sym == sub_eps_ || sym == 0) { + KALDI_ERR << " or disambiguation symbol " << sym << "found in the ARPA file. "; + } + if (sym == eos_symbol_) { + if (sub_eps_ == 0) { + // Keep as a real symbol when not substituting. + dest = eos_state_; + } else { + // Treat as if it was epsilon: mark source final, with the weight + // of the n-gram. + fst_->SetFinal(source, weight); + return; + } + } else { + // For the highest order n-gram, this may find an existing state, for + // non-highest, will create one (unless there are duplicate n-grams + // in the grammar, which cannot be reliably detected if highest order, + // so we better do not do that at all). + dest = AddStateWithBackoff( + HistKey(ngram.words.begin() + (is_highest ? 1 : 0), + ngram.words.end()), + -ngram.backoff); + } + + if (sym == bos_symbol_) { + weight = 0; // Accepting is always free. + if (sub_eps_ == 0) { + // is as a real symbol, only accepted in the start state. + source = fst_->AddState(); + fst_->SetStart(source); + } else { + // The new state for unigram history *is* the start state. + fst_->SetStart(dest); + return; + } + } + + // Add arc from source to dest, whichever way it was found. + fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest)); + return; +} + +// Find or create a new state for n-gram defined by key, and ensure it has a +// backoff transition. 
The key is either the current n-gram for all but +// highest orders, or the tails of the n-gram for the highest order. The +// latter arises from the chain-collapsing optimization described above. +template +StateId ArpaLmCompilerImpl::AddStateWithBackoff(HistKey key, + float backoff) { + typename HistoryMap::iterator dest_it = history_.find(key); + if (dest_it != history_.end()) { + // Found an existing state in the history map. Invariant: if the state in + // the map, then its backoff arc is in the FST. We are done. + return dest_it->second; + } + // Otherwise create a new state and its backoff arc, and register in the map. + StateId dest = fst_->AddState(); + history_[key] = dest; + CreateBackoff(key.Tails(), dest, backoff); + return dest; +} + +// Create a backoff arc for a state. Key is a backoff destination that may or +// may not exist. When the destination is not found, naturally fall back to +// the lower order model, and all the way down until one is found (since the +// 0-gram model is always present, the search is guaranteed to terminate). +template +inline void ArpaLmCompilerImpl::CreateBackoff( + HistKey key, StateId state, float weight) { + typename HistoryMap::iterator dest_it = history_.find(key); + while (dest_it == history_.end()) { + key = key.Tails(); + dest_it = history_.find(key); + } + + // The arc should transduce either or #0 to , depending on the + // epsilon substitution mode. This is the only case when input and output + // label may differ. + fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second)); +} + +ArpaLmCompiler::~ArpaLmCompiler() { + if (impl_ != NULL) + delete impl_; +} + +void ArpaLmCompiler::HeaderAvailable() { + KALDI_ASSERT(impl_ == NULL); + // Use optimized implementation if the grammar is 4-gram or less, and the + // maximum attained symbol id will fit into the optimized range. 
+ int64 max_symbol = 0; + if (Symbols() != NULL) + max_symbol = Symbols()->AvailableKey() - 1; + // If augmenting the symbol table, assume the worst case when all words in + // the model being read are novel. + if (Options().oov_handling == ArpaParseOptions::kAddToSymbols) + max_symbol += NgramCounts()[0]; + + if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) { + impl_ = new ArpaLmCompilerImpl(this, &fst_, sub_eps_); + } else { + impl_ = new ArpaLmCompilerImpl(this, &fst_, sub_eps_); + KALDI_LOG << "Reverting to slower state tracking because model is large: " + << NgramCounts().size() << "-gram with symbols up to " + << max_symbol; + } +} + +void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) { + // is invalid in tails, in heads of an n-gram. + for (int i = 0; i < ngram.words.size(); ++i) { + if ((i > 0 && ngram.words[i] == Options().bos_symbol) || + (i + 1 < ngram.words.size() + && ngram.words[i] == Options().eos_symbol)) { + if (ShouldWarn()) + KALDI_WARN << LineReference() + << " skipped: n-gram has invalid BOS/EOS placement"; + return; + } + } + + bool is_highest = ngram.words.size() == NgramCounts().size(); + impl_->ConsumeNGram(ngram, is_highest); +} + +void ArpaLmCompiler::RemoveRedundantStates() { + fst::StdArc::Label backoff_symbol = sub_eps_; + if (backoff_symbol == 0) { + // The method of removing redundant states implemented in this function + // leads to slow determinization of L o G when people use the older style of + // usage of arpa2fst where the --disambig-symbol option was not specified. + // The issue seems to be that it creates a non-deterministic FST, while G is + // supposed to be deterministic. By 'return'ing below, we just disable this + // method if people were using an older script. This method isn't really + // that consequential anyway, and people will move to the newer-style + // scripts (see current utils/format_lm.sh), so this isn't much of a + // problem. 
+ return; + } + + fst::StdArc::StateId num_states = fst_.NumStates(); + + + // replace the #0 symbols on the input of arcs out of redundant states (states + // that are not final and have only a backoff arc leaving them), with . + for (fst::StdArc::StateId state = 0; state < num_states; state++) { + if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) { + fst::MutableArcIterator iter(&fst_, state); + fst::StdArc arc = iter.Value(); + if (arc.ilabel == backoff_symbol) { + arc.ilabel = 0; + iter.SetValue(arc); + } + } + } + + // we could call fst::RemoveEps, and it would have the same effect in normal + // cases, where backoff_symbol != 0 and there are no epsilons in unexpected + // places, but RemoveEpsLocal is a bit safer in case something weird is going + // on; it guarantees not to blow up the FST. + fst::RemoveEpsLocal(&fst_); + KALDI_LOG << "Reduced num-states from " << num_states << " to " + << fst_.NumStates(); +} + +void ArpaLmCompiler::Check() const { + if (fst_.Start() == fst::kNoStateId) { + KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol " + << Symbols()->Find(Options().bos_symbol) << "."; + } +} + +void ArpaLmCompiler::ReadComplete() { + fst_.SetInputSymbols(Symbols()); + fst_.SetOutputSymbols(Symbols()); + RemoveRedundantStates(); + Check(); +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.h b/speechx/speechx/kaldi/lm/arpa-lm-compiler.h new file mode 100644 index 000000000..67a18273f --- /dev/null +++ b/speechx/speechx/kaldi/lm/arpa-lm-compiler.h @@ -0,0 +1,65 @@ +// lm/arpa-lm-compiler.h + +// Copyright 2009-2011 Gilles Boulianne +// Copyright 2016 Smart Action LLC (kkm) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_LM_ARPA_LM_COMPILER_H_ +#define KALDI_LM_ARPA_LM_COMPILER_H_ + +#include + +#include "lm/arpa-file-parser.h" + +namespace kaldi { + +class ArpaLmCompilerImplInterface; + +class ArpaLmCompiler : public ArpaFileParser { + public: + ArpaLmCompiler(const ArpaParseOptions& options, int sub_eps, + fst::SymbolTable* symbols) + : ArpaFileParser(options, symbols), + sub_eps_(sub_eps), impl_(NULL) { + } + ~ArpaLmCompiler(); + + const fst::StdVectorFst& Fst() const { return fst_; } + fst::StdVectorFst* MutableFst() { return &fst_; } + + protected: + // ArpaFileParser overrides. + virtual void HeaderAvailable(); + virtual void ConsumeNGram(const NGram& ngram); + virtual void ReadComplete(); + + + private: + // this function removes states that only have a backoff arc coming + // out of them. + void RemoveRedundantStates(); + void Check() const; + + int sub_eps_; + ArpaLmCompilerImplInterface* impl_; // Owned. 
+ fst::StdVectorFst fst_; + template friend class ArpaLmCompilerImpl; +}; + +} // namespace kaldi + +#endif // KALDI_LM_ARPA_LM_COMPILER_H_ diff --git a/speechx/tools/lmbin/CMakeLists.txt b/speechx/speechx/kaldi/lmbin/CMakeLists.txt similarity index 64% rename from speechx/tools/lmbin/CMakeLists.txt rename to speechx/speechx/kaldi/lmbin/CMakeLists.txt index 277e20776..2b0932f7d 100644 --- a/speechx/tools/lmbin/CMakeLists.txt +++ b/speechx/speechx/kaldi/lmbin/CMakeLists.txt @@ -1,5 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc) target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(arpa2fst ) +target_link_libraries(arpa2fst PUBLIC kaldi-lm glog gflags fst) diff --git a/speechx/tools/lmbin/arpa2fst.cc b/speechx/speechx/kaldi/lmbin/arpa2fst.cc similarity index 100% rename from speechx/tools/lmbin/arpa2fst.cc rename to speechx/speechx/kaldi/lmbin/arpa2fst.cc diff --git a/speechx/speechx/nnet/CMakeLists.txt b/speechx/speechx/nnet/CMakeLists.txt index cee881de8..c325ce755 100644 --- a/speechx/speechx/nnet/CMakeLists.txt +++ b/speechx/speechx/nnet/CMakeLists.txt @@ -4,4 +4,11 @@ add_library(nnet STATIC decodable.cc paddle_nnet.cc ) -target_link_libraries(nnet absl::strings) \ No newline at end of file +target_link_libraries(nnet absl::strings) + +set(bin_name nnet_forward_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS}) + + diff --git a/speechx/speechx/nnet/nnet_forward_main.cc b/speechx/speechx/nnet/nnet_forward_main.cc new file mode 100644 index 000000000..0d4ea8ff7 --- /dev/null +++ b/speechx/speechx/nnet/nnet_forward_main.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/flags.h" +#include "base/log.h" +#include "frontend/audio/assembler.h" +#include "frontend/audio/data_cache.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=3) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=3) module downsampling rate."); +DEFINE_string( + model_input_names, + "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", + "model input names"); +DEFINE_string(model_output_names, + "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", + "model output names"); +DEFINE_string(model_cache_names, + "chunk_state_h_box,chunk_state_c_box", + "model cache names"); +DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + 
google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model param: " << model_params; + + int32 num_done = 0, num_err = 0; + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.param_path = model_params; + model_opts.cache_names = FLAGS_model_cache_names; + model_opts.cache_shape = FLAGS_model_cache_shapes; + model_opts.input_names = FLAGS_model_input_names; + model_opts.output_names = FLAGS_model_output_names; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); + + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; + kaldi::Timer timer; + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + kaldi::Matrix feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + LOG(INFO) << "process utt: " << utt; + LOG(INFO) << "rows: " << feature.NumRows(); + LOG(INFO) << "cols: " << feature.NumCols(); + + int32 row_idx = 0; + int32 padding_len = 0; + int32 ori_feature_len = feature.NumRows(); + if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = + chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + 
feature.Resize(feature.NumRows() + padding_len, + feature.NumCols(), + kaldi::kCopyData); + } + int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; + int32 frame_idx = 0; + std::vector> prob_vec; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + int32 feature_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (feature_chunk_size < receptive_field_length) break; + + int32 start = chunk_idx * chunk_stride; + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, start); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + ++start; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + vector prob; + while (decodable->FrameLikelihood(frame_idx, &prob)) { + kaldi::Vector vec_tmp(prob.size()); + std::memcpy(vec_tmp.Data(), + prob.data(), + sizeof(kaldi::BaseFloat) * prob.size()); + prob_vec.push_back(vec_tmp); + frame_idx++; + } + } + decodable->Reset(); + if (prob_vec.size() == 0) { + // the TokenWriter can not write empty string. + ++num_err; + KALDI_LOG << " the nnet prob of " << utt << " is empty"; + continue; + } + kaldi::Matrix result(prob_vec.size(), + prob_vec[0].Dim()); + for (int32 row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { + for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { + result(row_idx, col_idx) = prob_vec[row_idx](col_idx); + } + } + + nnet_writer.Write(utt, result); + ++num_done; + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << " cost:" << elapsed << " s"; + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 
0 : 1); +} diff --git a/speechx/speechx/protocol/CMakeLists.txt b/speechx/speechx/protocol/CMakeLists.txt index e69de29bb..98b2f38b4 100644 --- a/speechx/speechx/protocol/CMakeLists.txt +++ b/speechx/speechx/protocol/CMakeLists.txt @@ -0,0 +1,3 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_subdirectory(websocket) diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/speechx/speechx/protocol/websocket/CMakeLists.txt new file mode 100644 index 000000000..c3454c399 --- /dev/null +++ b/speechx/speechx/protocol/websocket/CMakeLists.txt @@ -0,0 +1,15 @@ +project(websocket) + +add_library(websocket STATIC + websocket_server.cc + websocket_client.cc +) +target_link_libraries(websocket PUBLIC frontend decoder nnet) + +add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) +target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS}) + +add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) +target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS}) diff --git a/speechx/speechx/websocket/websocket_client.cc b/speechx/speechx/protocol/websocket/websocket_client.cc similarity index 96% rename from speechx/speechx/websocket/websocket_client.cc rename to speechx/speechx/protocol/websocket/websocket_client.cc index 6bd930b85..60e06db63 100644 --- a/speechx/speechx/websocket/websocket_client.cc +++ b/speechx/speechx/protocol/websocket/websocket_client.cc @@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() { if (obj["type"] == "final_result") { result_ = obj["result"].as_string().c_str(); } + if (obj["type"] == "partial_result") { + partial_result_ = obj["result"].as_string().c_str(); + } if (obj["type"] == "speech_end") { done_ = true; break; diff --git 
a/speechx/speechx/websocket/websocket_client.h b/speechx/speechx/protocol/websocket/websocket_client.h similarity index 91% rename from speechx/speechx/websocket/websocket_client.h rename to speechx/speechx/protocol/websocket/websocket_client.h index ac0aed310..886da2929 100644 --- a/speechx/speechx/websocket/websocket_client.h +++ b/speechx/speechx/protocol/websocket/websocket_client.h @@ -40,12 +40,14 @@ class WebSocketClient { void SendEndSignal(); void SendDataEnd(); bool Done() const { return done_; } - std::string GetResult() { return result_; } + std::string GetResult() const { return result_; } + std::string GetPartialResult() const { return partial_result_; } private: void Connect(); std::string host_; std::string result_; + std::string partial_result_; int port_; bool done_ = false; asio::io_context ioc_; diff --git a/speechx/examples/ds2_ol/websocket/websocket_client_main.cc b/speechx/speechx/protocol/websocket/websocket_client_main.cc similarity index 99% rename from speechx/examples/ds2_ol/websocket/websocket_client_main.cc rename to speechx/speechx/protocol/websocket/websocket_client_main.cc index df658b0a2..7ad36e3a5 100644 --- a/speechx/examples/ds2_ol/websocket/websocket_client_main.cc +++ b/speechx/speechx/protocol/websocket/websocket_client_main.cc @@ -59,7 +59,6 @@ int main(int argc, char* argv[]) { client.SendBinaryData(wav_chunk.data(), wav_chunk.size() * sizeof(int16)); - sample_offset += cur_chunk_size; LOG(INFO) << "Send " << cur_chunk_size << " samples"; std::this_thread::sleep_for( diff --git a/speechx/speechx/websocket/websocket_server.cc b/speechx/speechx/protocol/websocket/websocket_server.cc similarity index 96% rename from speechx/speechx/websocket/websocket_server.cc rename to speechx/speechx/protocol/websocket/websocket_server.cc index 28c9eca4e..14f2f6e9f 100644 --- a/speechx/speechx/websocket/websocket_server.cc +++ b/speechx/speechx/protocol/websocket/websocket_server.cc @@ -75,9 +75,11 @@ void 
ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { CHECK(recognizer_ != nullptr); recognizer_->Accept(pcm_data); - // TODO: return lpartial result - json::value rv = { - {"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}}; + std::string partial_result = recognizer_->GetPartialResult(); + + json::value rv = {{"status", "ok"}, + {"type", "partial_result"}, + {"result", partial_result}}; ws_.text(true); ws_.write(asio::buffer(json::serialize(rv))); } diff --git a/speechx/speechx/websocket/websocket_server.h b/speechx/speechx/protocol/websocket/websocket_server.h similarity index 98% rename from speechx/speechx/websocket/websocket_server.h rename to speechx/speechx/protocol/websocket/websocket_server.h index 9ea88282e..009fc42ed 100644 --- a/speechx/speechx/websocket/websocket_server.h +++ b/speechx/speechx/protocol/websocket/websocket_server.h @@ -44,7 +44,6 @@ class ConnectionHandler { void OnFinish(); void OnSpeechData(const beast::flat_buffer& buffer); void OnError(const std::string& message); - void OnPartialResult(const std::string& result); void OnFinalResult(const std::string& result); void DecodeThreadFunc(); std::string SerializeResult(bool finish); diff --git a/speechx/examples/ds2_ol/websocket/websocket_server_main.cc b/speechx/speechx/protocol/websocket/websocket_server_main.cc similarity index 100% rename from speechx/examples/ds2_ol/websocket/websocket_server_main.cc rename to speechx/speechx/protocol/websocket/websocket_server_main.cc diff --git a/speechx/speechx/utils/CMakeLists.txt b/speechx/speechx/utils/CMakeLists.txt index 08d115281..95e865744 100644 --- a/speechx/speechx/utils/CMakeLists.txt +++ b/speechx/speechx/utils/CMakeLists.txt @@ -1,5 +1,4 @@ add_library(utils file_utils.cc - simdjson.cpp -) +) \ No newline at end of file diff --git a/speechx/speechx/utils/simdjson.cpp b/speechx/speechx/utils/simdjson.cpp deleted file mode 100644 index 8f1a9e284..000000000 --- a/speechx/speechx/utils/simdjson.cpp +++ /dev/null 
@@ -1,16016 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* auto-generated on 2022-01-31 11:38:54 -0500. Do not edit! */ -/* begin file src/simdjson.cpp */ -#include "simdjson.h" - -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_UNDESIRED_WARNINGS - -/* begin file src/to_chars.cpp */ -#include -#include -#include -#include - -namespace simdjson { -namespace internal { -/*! -implements the Grisu2 algorithm for binary to decimal floating-point -conversion. -Adapted from JSON for Modern C++ - -This implementation is a slightly modified version of the reference -implementation which may be obtained from -http://florian.loitsch.com/publications (bench.tar.gz). -The code is distributed under the MIT license, Copyright (c) 2009 Florian -Loitsch. 
For a detailed description of the algorithm see: [1] Loitsch, "Printing -Floating-Point Numbers Quickly and Accurately with Integers", Proceedings of the -ACM SIGPLAN 2010 Conference on Programming Language Design and Implementation, -PLDI 2010 [2] Burger, Dybvig, "Printing Floating-Point Numbers Quickly and -Accurately", Proceedings of the ACM SIGPLAN 1996 Conference on Programming -Language Design and Implementation, PLDI 1996 -*/ -namespace dtoa_impl { - -template -Target reinterpret_bits(const Source source) { - static_assert(sizeof(Target) == sizeof(Source), "size mismatch"); - - Target target; - std::memcpy(&target, &source, sizeof(Source)); - return target; -} - -struct diyfp // f * 2^e -{ - static constexpr int kPrecision = 64; // = q - - std::uint64_t f = 0; - int e = 0; - - constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {} - - /*! - @brief returns x - y - @pre x.e == y.e and x.f >= y.f - */ - static diyfp sub(const diyfp &x, const diyfp &y) noexcept { - return {x.f - y.f, x.e}; - } - - /*! - @brief returns x * y - @note The result is rounded. (Only the upper q bits are returned.) 
- */ - static diyfp mul(const diyfp &x, const diyfp &y) noexcept { - static_assert(kPrecision == 64, "internal error"); - - // Computes: - // f = round((x.f * y.f) / 2^q) - // e = x.e + y.e + q - - // Emulate the 64-bit * 64-bit multiplication: - // - // p = u * v - // = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi) - // = (u_lo v_lo ) + 2^32 ((u_lo v_hi ) + (u_hi v_lo )) - // + - // 2^64 (u_hi v_hi ) = (p0 ) + 2^32 ((p1 ) + - // (p2 )) - // + 2^64 (p3 ) = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + - // 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3 ) = - // (p0_lo ) + 2^32 (p0_hi + p1_lo + p2_lo ) + 2^64 (p1_hi - // + - // p2_hi + p3) = (p0_lo ) + 2^32 (Q ) + 2^64 (H ) = (p0_lo - // ) + - // 2^32 (Q_lo + 2^32 Q_hi ) + 2^64 (H ) - // - // (Since Q might be larger than 2^32 - 1) - // - // = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H) - // - // (Q_hi + H does not overflow a 64-bit int) - // - // = p_lo + 2^64 p_hi - - const std::uint64_t u_lo = x.f & 0xFFFFFFFFu; - const std::uint64_t u_hi = x.f >> 32u; - const std::uint64_t v_lo = y.f & 0xFFFFFFFFu; - const std::uint64_t v_hi = y.f >> 32u; - - const std::uint64_t p0 = u_lo * v_lo; - const std::uint64_t p1 = u_lo * v_hi; - const std::uint64_t p2 = u_hi * v_lo; - const std::uint64_t p3 = u_hi * v_hi; - - const std::uint64_t p0_hi = p0 >> 32u; - const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu; - const std::uint64_t p1_hi = p1 >> 32u; - const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu; - const std::uint64_t p2_hi = p2 >> 32u; - - std::uint64_t Q = p0_hi + p1_lo + p2_lo; - - // The full product might now be computed as - // - // p_hi = p3 + p2_hi + p1_hi + (Q >> 32) - // p_lo = p0_lo + (Q << 32) - // - // But in this particular case here, the full p_lo is not required. - // Effectively we only need to add the highest bit in p_lo to p_hi (and - // Q_hi + 1 does not overflow). - - Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up - - const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u); - - return {h, x.e + y.e + 64}; - } - - /*! 
- @brief normalize x such that the significand is >= 2^(q-1) - @pre x.f != 0 - */ - static diyfp normalize(diyfp x) noexcept { - while ((x.f >> 63u) == 0) { - x.f <<= 1u; - x.e--; - } - - return x; - } - - /*! - @brief normalize x such that the result has the exponent E - @pre e >= x.e and the upper e - x.e bits of x.f must be zero. - */ - static diyfp normalize_to(const diyfp &x, - const int target_exponent) noexcept { - const int delta = x.e - target_exponent; - - return {x.f << delta, target_exponent}; - } -}; - -struct boundaries { - diyfp w; - diyfp minus; - diyfp plus; -}; - -/*! -Compute the (normalized) diyfp representing the input number 'value' and its -boundaries. -@pre value must be finite and positive -*/ -template -boundaries compute_boundaries(FloatType value) { - // Convert the IEEE representation into a diyfp. - // - // If v is denormal: - // value = 0.F * 2^(1 - bias) = ( F) * 2^(1 - bias - (p-1)) - // If v is normalized: - // value = 1.F * 2^(E - bias) = (2^(p-1) + F) * 2^(E - bias - (p-1)) - - static_assert(std::numeric_limits::is_iec559, - "internal error: dtoa_short requires an IEEE-754 " - "floating-point implementation"); - - constexpr int kPrecision = std::numeric_limits< - FloatType>::digits; // = p (includes the hidden bit) - constexpr int kBias = - std::numeric_limits::max_exponent - 1 + (kPrecision - 1); - constexpr int kMinExp = 1 - kBias; - constexpr std::uint64_t kHiddenBit = std::uint64_t{1} - << (kPrecision - 1); // = 2^(p-1) - - using bits_type = typename std::conditional::type; - - const std::uint64_t bits = reinterpret_bits(value); - const std::uint64_t E = bits >> (kPrecision - 1); - const std::uint64_t F = bits & (kHiddenBit - 1); - - const bool is_denormal = E == 0; - const diyfp v = is_denormal - ? diyfp(F, kMinExp) - : diyfp(F + kHiddenBit, static_cast(E) - kBias); - - // Compute the boundaries m- and m+ of the floating-point value - // v = f * 2^e. 
- // - // Determine v- and v+, the floating-point predecessor and successor if v, - // respectively. - // - // v- = v - 2^e if f != 2^(p-1) or e == e_min (A) - // = v - 2^(e-1) if f == 2^(p-1) and e > e_min (B) - // - // v+ = v + 2^e - // - // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. All real numbers _strictly_ - // between m- and m+ round to v, regardless of how the input rounding - // algorithm breaks ties. - // - // ---+-------------+-------------+-------------+-------------+--- (A) - // v- m- v m+ v+ - // - // -----------------+------+------+-------------+-------------+--- (B) - // v- m- v m+ v+ - - const bool lower_boundary_is_closer = F == 0 && E > 1; - const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1); - const diyfp m_minus = lower_boundary_is_closer - ? diyfp(4 * v.f - 1, v.e - 2) // (B) - : diyfp(2 * v.f - 1, v.e - 1); // (A) - - // Determine the normalized w+ = m+. - const diyfp w_plus = diyfp::normalize(m_plus); - - // Determine w- = m- such that e_(w-) = e_(w+). - const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e); - - return {diyfp::normalize(v), w_minus, w_plus}; -} - -// Given normalized diyfp w, Grisu needs to find a (normalized) cached -// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies -// within a certain range [alpha, gamma] (Definition 3.2 from [1]) -// -// alpha <= e = e_c + e_w + q <= gamma -// -// or -// -// f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q -// <= f_c * f_w * 2^gamma -// -// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies -// -// 2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma -// -// or -// -// 2^(q - 2 + alpha) <= c * w < 2^(q + gamma) -// -// The choice of (alpha,gamma) determines the size of the table and the form of -// the digit generation procedure. 
Using (alpha,gamma)=(-60,-32) works out well -// in practice: -// -// The idea is to cut the number c * w = f * 2^e into two parts, which can be -// processed independently: An integral part p1, and a fractional part p2: -// -// f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e -// = (f div 2^-e) + (f mod 2^-e) * 2^e -// = p1 + p2 * 2^e -// -// The conversion of p1 into decimal form requires a series of divisions and -// modulos by (a power of) 10. These operations are faster for 32-bit than for -// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be -// achieved by choosing -// -// -e >= 32 or e <= -32 := gamma -// -// In order to convert the fractional part -// -// p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ... -// -// into decimal form, the fraction is repeatedly multiplied by 10 and the digits -// d[-i] are extracted in order: -// -// (10 * p2) div 2^-e = d[-1] -// (10 * p2) mod 2^-e = d[-2] / 10^1 + ... -// -// The multiplication by 10 must not overflow. It is sufficient to choose -// -// 10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64. -// -// Since p2 = f mod 2^-e < 2^-e, -// -// -e <= 60 or e >= -60 := alpha - -constexpr int kAlpha = -60; -constexpr int kGamma = -32; - -struct cached_power // c = f * 2^e ~= 10^k -{ - std::uint64_t f; - int e; - int k; -}; - -/*! -For a normalized diyfp w = f * 2^e, this function returns a (normalized) cached -power-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c -satisfies (Definition 3.2 from [1]) - alpha <= e_c + e + q <= gamma. -*/ -inline cached_power get_cached_power_for_binary_exponent(int e) { - // Now - // - // alpha <= e_c + e + q <= gamma (1) - // ==> f_c * 2^alpha <= c * 2^e * 2^q - // - // and since the c's are normalized, 2^(q-1) <= f_c, - // - // ==> 2^(q - 1 + alpha) <= c * 2^(e + q) - // ==> 2^(alpha - e - 1) <= c - // - // If c were an exact power of ten, i.e. 
c = 10^k, one may determine k as - // - // k = ceil( log_10( 2^(alpha - e - 1) ) ) - // = ceil( (alpha - e - 1) * log_10(2) ) - // - // From the paper: - // "In theory the result of the procedure could be wrong since c is rounded, - // and the computation itself is approximated [...]. In practice, however, - // this simple function is sufficient." - // - // For IEEE double precision floating-point numbers converted into - // normalized diyfp's w = f * 2^e, with q = 64, - // - // e >= -1022 (min IEEE exponent) - // -52 (p - 1) - // -52 (p - 1, possibly normalize denormal IEEE numbers) - // -11 (normalize the diyfp) - // = -1137 - // - // and - // - // e <= +1023 (max IEEE exponent) - // -52 (p - 1) - // -11 (normalize the diyfp) - // = 960 - // - // This binary exponent range [-1137,960] results in a decimal exponent - // range [-307,324]. One does not need to store a cached power for each - // k in this range. For each such k it suffices to find a cached power - // such that the exponent of the product lies in [alpha,gamma]. - // This implies that the difference of the decimal exponents of adjacent - // table entries must be less than or equal to - // - // floor( (gamma - alpha) * log_10(2) ) = 8. - // - // (A smaller distance gamma-alpha would require a larger table.) - - // NB: - // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34. 
- - constexpr int kCachedPowersMinDecExp = -300; - constexpr int kCachedPowersDecStep = 8; - - static constexpr std::array kCachedPowers = {{ - {0xAB70FE17C79AC6CA, -1060, -300}, {0xFF77B1FCBEBCDC4F, -1034, -292}, - {0xBE5691EF416BD60C, -1007, -284}, {0x8DD01FAD907FFC3C, -980, -276}, - {0xD3515C2831559A83, -954, -268}, {0x9D71AC8FADA6C9B5, -927, -260}, - {0xEA9C227723EE8BCB, -901, -252}, {0xAECC49914078536D, -874, -244}, - {0x823C12795DB6CE57, -847, -236}, {0xC21094364DFB5637, -821, -228}, - {0x9096EA6F3848984F, -794, -220}, {0xD77485CB25823AC7, -768, -212}, - {0xA086CFCD97BF97F4, -741, -204}, {0xEF340A98172AACE5, -715, -196}, - {0xB23867FB2A35B28E, -688, -188}, {0x84C8D4DFD2C63F3B, -661, -180}, - {0xC5DD44271AD3CDBA, -635, -172}, {0x936B9FCEBB25C996, -608, -164}, - {0xDBAC6C247D62A584, -582, -156}, {0xA3AB66580D5FDAF6, -555, -148}, - {0xF3E2F893DEC3F126, -529, -140}, {0xB5B5ADA8AAFF80B8, -502, -132}, - {0x87625F056C7C4A8B, -475, -124}, {0xC9BCFF6034C13053, -449, -116}, - {0x964E858C91BA2655, -422, -108}, {0xDFF9772470297EBD, -396, -100}, - {0xA6DFBD9FB8E5B88F, -369, -92}, {0xF8A95FCF88747D94, -343, -84}, - {0xB94470938FA89BCF, -316, -76}, {0x8A08F0F8BF0F156B, -289, -68}, - {0xCDB02555653131B6, -263, -60}, {0x993FE2C6D07B7FAC, -236, -52}, - {0xE45C10C42A2B3B06, -210, -44}, {0xAA242499697392D3, -183, -36}, - {0xFD87B5F28300CA0E, -157, -28}, {0xBCE5086492111AEB, -130, -20}, - {0x8CBCCC096F5088CC, -103, -12}, {0xD1B71758E219652C, -77, -4}, - {0x9C40000000000000, -50, 4}, {0xE8D4A51000000000, -24, 12}, - {0xAD78EBC5AC620000, 3, 20}, {0x813F3978F8940984, 30, 28}, - {0xC097CE7BC90715B3, 56, 36}, {0x8F7E32CE7BEA5C70, 83, 44}, - {0xD5D238A4ABE98068, 109, 52}, {0x9F4F2726179A2245, 136, 60}, - {0xED63A231D4C4FB27, 162, 68}, {0xB0DE65388CC8ADA8, 189, 76}, - {0x83C7088E1AAB65DB, 216, 84}, {0xC45D1DF942711D9A, 242, 92}, - {0x924D692CA61BE758, 269, 100}, {0xDA01EE641A708DEA, 295, 108}, - {0xA26DA3999AEF774A, 322, 116}, {0xF209787BB47D6B85, 348, 124}, - {0xB454E4A179DD1877, 375, 
132}, {0x865B86925B9BC5C2, 402, 140}, - {0xC83553C5C8965D3D, 428, 148}, {0x952AB45CFA97A0B3, 455, 156}, - {0xDE469FBD99A05FE3, 481, 164}, {0xA59BC234DB398C25, 508, 172}, - {0xF6C69A72A3989F5C, 534, 180}, {0xB7DCBF5354E9BECE, 561, 188}, - {0x88FCF317F22241E2, 588, 196}, {0xCC20CE9BD35C78A5, 614, 204}, - {0x98165AF37B2153DF, 641, 212}, {0xE2A0B5DC971F303A, 667, 220}, - {0xA8D9D1535CE3B396, 694, 228}, {0xFB9B7CD9A4A7443C, 720, 236}, - {0xBB764C4CA7A44410, 747, 244}, {0x8BAB8EEFB6409C1A, 774, 252}, - {0xD01FEF10A657842C, 800, 260}, {0x9B10A4E5E9913129, 827, 268}, - {0xE7109BFBA19C0C9D, 853, 276}, {0xAC2820D9623BF429, 880, 284}, - {0x80444B5E7AA7CF85, 907, 292}, {0xBF21E44003ACDD2D, 933, 300}, - {0x8E679C2F5E44FF8F, 960, 308}, {0xD433179D9C8CB841, 986, 316}, - {0x9E19DB92B4E31BA9, 1013, 324}, - }}; - - // This computation gives exactly the same results for k as - // k = ceil((kAlpha - e - 1) * 0.30102999566398114) - // for |e| <= 1500, but doesn't require floating-point operations. - // NB: log_10(2) ~= 78913 / 2^18 - const int f = kAlpha - e - 1; - const int k = (f * 78913) / (1 << 18) + static_cast(f > 0); - - const int index = - (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / - kCachedPowersDecStep; - - const cached_power cached = kCachedPowers[static_cast(index)]; - - return cached; -} - -/*! -For n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k. -For n == 0, returns 1 and sets pow10 := 1. 
-*/ -inline int find_largest_pow10(const std::uint32_t n, std::uint32_t &pow10) { - // LCOV_EXCL_START - if (n >= 1000000000) { - pow10 = 1000000000; - return 10; - } - // LCOV_EXCL_STOP - else if (n >= 100000000) { - pow10 = 100000000; - return 9; - } else if (n >= 10000000) { - pow10 = 10000000; - return 8; - } else if (n >= 1000000) { - pow10 = 1000000; - return 7; - } else if (n >= 100000) { - pow10 = 100000; - return 6; - } else if (n >= 10000) { - pow10 = 10000; - return 5; - } else if (n >= 1000) { - pow10 = 1000; - return 4; - } else if (n >= 100) { - pow10 = 100; - return 3; - } else if (n >= 10) { - pow10 = 10; - return 2; - } else { - pow10 = 1; - return 1; - } -} - -inline void grisu2_round(char *buf, - int len, - std::uint64_t dist, - std::uint64_t delta, - std::uint64_t rest, - std::uint64_t ten_k) { - // <--------------------------- delta ----> - // <---- dist ---------> - // --------------[------------------+-------------------]-------------- - // M- w M+ - // - // ten_k - // <------> - // <---- rest ----> - // --------------[------------------+----+--------------]-------------- - // w V - // = buf * 10^k - // - // ten_k represents a unit-in-the-last-place in the decimal representation - // stored in buf. - // Decrement buf by ten_k while this takes buf closer to w. - - // The tests are written in this order to avoid overflow in unsigned - // integer arithmetic. - - while (rest < dist && delta - rest >= ten_k && - (rest + ten_k < dist || dist - rest > rest + ten_k - dist)) { - buf[len - 1]--; - rest += ten_k; - } -} - -/*! -Generates V = buffer * 10^decimal_exponent, such that M- <= V <= M+. -M- and M+ must be normalized and share the same exponent -60 <= e <= -32. 
-*/ -inline void grisu2_digit_gen(char *buffer, - int &length, - int &decimal_exponent, - diyfp M_minus, - diyfp w, - diyfp M_plus) { - static_assert(kAlpha >= -60, "internal error"); - static_assert(kGamma <= -32, "internal error"); - - // Generates the digits (and the exponent) of a decimal floating-point - // number V = buffer * 10^decimal_exponent in the range [M-, M+]. The - // diyfp's - // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= - // gamma. - // - // <--------------------------- delta ----> - // <---- dist ---------> - // --------------[------------------+-------------------]-------------- - // M- w M+ - // - // Grisu2 generates the digits of M+ from left to right and stops as soon as - // V is in [M-,M+]. - - std::uint64_t delta = - diyfp::sub(M_plus, M_minus) - .f; // (significand of (M+ - M-), implicit exponent is e) - std::uint64_t dist = - diyfp::sub(M_plus, w) - .f; // (significand of (M+ - w ), implicit exponent is e) - - // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0): - // - // M+ = f * 2^e - // = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e - // = ((p1 ) * 2^-e + (p2 )) * 2^e - // = p1 + p2 * 2^e - - const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e); - - auto p1 = static_cast(M_plus.f >> -one.e); // p1 = f div - // 2^-e (Since -e - // >= 32, p1 fits - // into a 32-bit - // int.) 
- std::uint64_t p2 = M_plus.f & (one.f - 1); // p2 = f mod 2^-e - - // 1) - // - // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0] - - std::uint32_t pow10; - const int k = find_largest_pow10(p1, pow10); - - // 10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1) - // - // p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1)) - // = (d[k-1] ) * 10^(k-1) + (p1 mod 10^(k-1)) - // - // M+ = p1 + p2 * 2^e - // = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1)) + p2 * 2^e - // = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e - // = d[k-1] * 10^(k-1) + ( rest) * 2^e - // - // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0) - // - // p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0] - // - // but stop as soon as - // - // rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e - - int n = k; - while (n > 0) { - // Invariants: - // M+ = buffer * 10^n + (p1 + p2 * 2^e) (buffer = 0 for n = k) - // pow10 = 10^(n-1) <= p1 < 10^n - // - const std::uint32_t d = p1 / pow10; // d = p1 div 10^(n-1) - const std::uint32_t r = p1 % pow10; // r = p1 mod 10^(n-1) - // - // M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e - // = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e) - // - buffer[length++] = - static_cast('0' + d); // buffer := buffer * 10 + d - // - // M+ = buffer * 10^(n-1) + (r + p2 * 2^e) - // - p1 = r; - n--; - // - // M+ = buffer * 10^n + (p1 + p2 * 2^e) - // pow10 = 10^n - // - - // Now check if enough digits have been generated. - // Compute - // - // p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e - // - // Note: - // Since rest and delta share the same exponent e, it suffices to - // compare the significands. - const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2; - if (rest <= delta) { - // V = buffer * 10^n, with M- <= V <= M+. - - decimal_exponent += n; - - // We may now just stop. But instead look if the buffer could be - // decremented to bring V closer to w. - // - // pow10 = 10^n is now 1 ulp in the decimal representation V. 
- // The rounding procedure works with diyfp's with an implicit - // exponent of e. - // - // 10^n = (10^n * 2^-e) * 2^e = ulp * 2^e - // - const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e; - grisu2_round(buffer, length, dist, delta, rest, ten_n); - - return; - } - - pow10 /= 10; - // - // pow10 = 10^(n-1) <= p1 < 10^n - // Invariants restored. - } - - // 2) - // - // The digits of the integral part have been generated: - // - // M+ = d[k-1]...d[1]d[0] + p2 * 2^e - // = buffer + p2 * 2^e - // - // Now generate the digits of the fractional part p2 * 2^e. - // - // Note: - // No decimal point is generated: the exponent is adjusted instead. - // - // p2 actually represents the fraction - // - // p2 * 2^e - // = p2 / 2^-e - // = d[-1] / 10^1 + d[-2] / 10^2 + ... - // - // Now generate the digits d[-m] of p1 from left to right (m = 1,2,...) - // - // p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m - // + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...) - // - // using - // - // 10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e) - // = ( d) * 2^-e + ( r) - // - // or - // 10^m * p2 * 2^e = d + r * 2^e - // - // i.e. - // - // M+ = buffer + p2 * 2^e - // = buffer + 10^-m * (d + r * 2^e) - // = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e - // - // and stop as soon as 10^-m * r * 2^e <= delta * 2^e - - int m = 0; - for (;;) { - // Invariant: - // M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + - // ...) 
- // * 2^e - // = buffer * 10^-m + 10^-m * (p2 ) - // * 2^e = buffer * 10^-m + 10^-m * (1/10 * (10 * p2) ) * 2^e = - // buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + - // (10*p2 mod 2^-e)) * 2^e - // - p2 *= 10; - const std::uint64_t d = p2 >> -one.e; // d = (10 * p2) div 2^-e - const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e - // - // M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e - // = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e)) - // = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e - // - buffer[length++] = - static_cast('0' + d); // buffer := buffer * 10 + d - // - // M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e - // - p2 = r; - m++; - // - // M+ = buffer * 10^-m + 10^-m * p2 * 2^e - // Invariant restored. - - // Check if enough digits have been generated. - // - // 10^-m * p2 * 2^e <= delta * 2^e - // p2 * 2^e <= 10^m * delta * 2^e - // p2 <= 10^m * delta - delta *= 10; - dist *= 10; - if (p2 <= delta) { - break; - } - } - - // V = buffer * 10^-m, with M- <= V <= M+. - - decimal_exponent -= m; - - // 1 ulp in the decimal representation is now 10^-m. - // Since delta and dist are now scaled by 10^m, we need to do the - // same with ulp in order to keep the units in sync. - // - // 10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e - // - const std::uint64_t ten_m = one.f; - grisu2_round(buffer, length, dist, delta, p2, ten_m); - - // By construction this algorithm generates the shortest possible decimal - // number (Loitsch, Theorem 6.2) which rounds back to w. - // For an input number of precision p, at least - // - // N = 1 + ceil(p * log_10(2)) - // - // decimal digits are sufficient to identify all binary floating-point - // numbers (Matula, "In-and-Out conversions"). - // This implies that the algorithm does not produce more than N decimal - // digits. - // - // N = 17 for p = 53 (IEEE double precision) - // N = 9 for p = 24 (IEEE single precision) -} - -/*! 
-v = buf * 10^decimal_exponent -len is the length of the buffer (number of decimal digits) -The buffer must be large enough, i.e. >= max_digits10. -*/ -inline void grisu2(char *buf, - int &len, - int &decimal_exponent, - diyfp m_minus, - diyfp v, - diyfp m_plus) { - // --------(-----------------------+-----------------------)-------- (A) - // m- v m+ - // - // --------------------(-----------+-----------------------)-------- (B) - // m- v m+ - // - // First scale v (and m- and m+) such that the exponent is in the range - // [alpha, gamma]. - - const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e); - - const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k - - // The exponent of the products is = v.e + c_minus_k.e + q and is in the - // range - // [alpha,gamma] - const diyfp w = diyfp::mul(v, c_minus_k); - const diyfp w_minus = diyfp::mul(m_minus, c_minus_k); - const diyfp w_plus = diyfp::mul(m_plus, c_minus_k); - - // ----(---+---)---------------(---+---)---------------(---+---)---- - // w- w w+ - // = c*m- = c*v = c*m+ - // - // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and - // w+ are now off by a small amount. - // In fact: - // - // w - v * 10^k < 1 ulp - // - // To account for this inaccuracy, add resp. subtract 1 ulp. - // - // --------+---[---------------(---+---)---------------]---+-------- - // w- M- w M+ w+ - // - // Now any number in [M-, M+] (bounds included) will round to w when input, - // regardless of how the input rounding algorithm breaks ties. - // - // And digit_gen generates the shortest possible such number in [M-, M+]. - // Note that this does not mean that Grisu2 always generates the shortest - // possible number in the interval (m-, m+). - const diyfp M_minus(w_minus.f + 1, w_minus.e); - const diyfp M_plus(w_plus.f - 1, w_plus.e); - - decimal_exponent = -cached.k; // = -(-k) = k - - grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus); -} - -/*! 
-v = buf * 10^decimal_exponent -len is the length of the buffer (number of decimal digits) -The buffer must be large enough, i.e. >= max_digits10. -*/ -template -void grisu2(char *buf, int &len, int &decimal_exponent, FloatType value) { - static_assert( - diyfp::kPrecision >= std::numeric_limits::digits + 3, - "internal error: not enough precision"); - -// If the neighbors (and boundaries) of 'value' are always computed for -// double-precision numbers, all float's can be recovered using strtod (and -// strtof). However, the resulting decimal representations are not exactly -// "short". -// -// The documentation for 'std::to_chars' -// (https://en.cppreference.com/w/cpp/utility/to_chars) says "value is -// converted to a string as if by std::sprintf in the default ("C") locale" -// and since sprintf promotes float's to double's, I think this is exactly -// what 'std::to_chars' does. On the other hand, the documentation for -// 'std::to_chars' requires that "parsing the representation using the -// corresponding std::from_chars function recovers value exactly". That -// indicates that single precision floating-point numbers should be recovered -// using 'std::strtof'. -// -// NB: If the neighbors are computed for single-precision numbers, there is a -// single float -// (7.0385307e-26f) which can't be recovered using strtod. The resulting -// double precision value is off by 1 ulp. -#if 0 - const boundaries w = compute_boundaries(static_cast(value)); -#else - const boundaries w = compute_boundaries(value); -#endif - - grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus); -} - -/*! -@brief appends a decimal representation of e to buf -@return a pointer to the element following the exponent. -@pre -1000 < e < 1000 -*/ -inline char *append_exponent(char *buf, int e) { - if (e < 0) { - e = -e; - *buf++ = '-'; - } else { - *buf++ = '+'; - } - - auto k = static_cast(e); - if (k < 10) { - // Always print at least two digits in the exponent. 
- // This is for compatibility with printf("%g"). - *buf++ = '0'; - *buf++ = static_cast('0' + k); - } else if (k < 100) { - *buf++ = static_cast('0' + k / 10); - k %= 10; - *buf++ = static_cast('0' + k); - } else { - *buf++ = static_cast('0' + k / 100); - k %= 100; - *buf++ = static_cast('0' + k / 10); - k %= 10; - *buf++ = static_cast('0' + k); - } - - return buf; -} - -/*! -@brief prettify v = buf * 10^decimal_exponent -If v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point -notation. Otherwise it will be printed in exponential notation. -@pre min_exp < 0 -@pre max_exp > 0 -*/ -inline char *format_buffer( - char *buf, int len, int decimal_exponent, int min_exp, int max_exp) { - const int k = len; - const int n = len + decimal_exponent; - - // v = buf * 10^(n-k) - // k is the length of the buffer (number of decimal digits) - // n is the position of the decimal point relative to the start of the - // buffer. - - if (k <= n && n <= max_exp) { - // digits[000] - // len <= max_exp + 2 - - std::memset( - buf + k, '0', static_cast(n) - static_cast(k)); - // Make it look like a floating-point number (#362, #378) - // buf[n + 0] = '.'; - // buf[n + 1] = '0'; - return buf + (static_cast(n)); - } - - if (0 < n && n <= max_exp) { - // dig.its - // len <= max_digits10 + 1 - std::memmove(buf + (static_cast(n) + 1), - buf + n, - static_cast(k) - static_cast(n)); - buf[n] = '.'; - return buf + (static_cast(k) + 1U); - } - - if (min_exp < n && n <= 0) { - // 0.[000]digits - // len <= 2 + (-min_exp - 1) + max_digits10 - - std::memmove( - buf + (2 + static_cast(-n)), buf, static_cast(k)); - buf[0] = '0'; - buf[1] = '.'; - std::memset(buf + 2, '0', static_cast(-n)); - return buf + (2U + static_cast(-n) + static_cast(k)); - } - - if (k == 1) { - // dE+123 - // len <= 1 + 5 - - buf += 1; - } else { - // d.igitsE+123 - // len <= max_digits10 + 1 + 5 - - std::memmove(buf + 2, buf + 1, static_cast(k) - 1); - buf[1] = '.'; - buf += 1 + static_cast(k); - } - - 
*buf++ = 'e'; - return append_exponent(buf, n - 1); -} - -} // namespace dtoa_impl - -/*! -The format of the resulting decimal representation is similar to printf's %g -format. Returns an iterator pointing past-the-end of the decimal representation. -@note The input number must be finite, i.e. NaN's and Inf's are not supported. -@note The buffer must be large enough. -@note The result is NOT null-terminated. -*/ -char *to_chars(char *first, const char *last, double value) { - static_cast(last); // maybe unused - fix warning - bool negative = std::signbit(value); - if (negative) { - value = -value; - *first++ = '-'; - } - - if (value == 0) // +-0 - { - *first++ = '0'; - // Make it look like a floating-point number (#362, #378) - if (negative) { - *first++ = '.'; - *first++ = '0'; - } - return first; - } - // Compute v = buffer * 10^decimal_exponent. - // The decimal digits are stored in the buffer, which needs to be - // interpreted - // as an unsigned decimal integer. - // len is the length of the buffer, i.e. the number of decimal digits. - int len = 0; - int decimal_exponent = 0; - dtoa_impl::grisu2(first, len, decimal_exponent, value); - // Format the buffer like printf("%.*g", prec, value) - constexpr int kMinExp = -4; - constexpr int kMaxExp = std::numeric_limits::digits10; - - return dtoa_impl::format_buffer( - first, len, decimal_exponent, kMinExp, kMaxExp); -} -} // namespace internal -} // namespace simdjson -/* end file src/to_chars.cpp */ -/* begin file src/from_chars.cpp */ -#include -namespace simdjson { -namespace internal { - -/** - * The code in the internal::from_chars function is meant to handle the - *floating-point number parsing - * when we have more than 19 digits in the decimal mantissa. This should only be - *seen - * in adversarial scenarios: we do not expect production systems to even produce - * such floating-point numbers. 
- * - * The parser is based on work by Nigel Tao (at - *https://github.com/google/wuffs/) - * who credits Ken Thompson for the design (via a reference to the Go source - * code). See - * https://github.com/google/wuffs/blob/aa46859ea40c72516deffa1b146121952d6dfd3b/internal/cgen/base/floatconv-submodule-data.c - * https://github.com/google/wuffs/blob/46cd8105f47ca07ae2ba8e6a7818ef9c0df6c152/internal/cgen/base/floatconv-submodule-code.c - * It is probably not very fast but it is a fallback that should almost never be - * called in real life. Google Wuffs is published under APL 2.0. - **/ - -namespace { -constexpr uint32_t max_digits = 768; -constexpr int32_t decimal_point_range = 2047; -} // namespace - -struct adjusted_mantissa { - uint64_t mantissa; - int power2; - adjusted_mantissa() : mantissa(0), power2(0) {} -}; - -struct decimal { - uint32_t num_digits; - int32_t decimal_point; - bool negative; - bool truncated; - uint8_t digits[max_digits]; -}; - -template -struct binary_format { - static constexpr int mantissa_explicit_bits(); - static constexpr int minimum_exponent(); - static constexpr int infinite_power(); - static constexpr int sign_index(); -}; - -template <> -constexpr int binary_format::mantissa_explicit_bits() { - return 52; -} - -template <> -constexpr int binary_format::minimum_exponent() { - return -1023; -} -template <> -constexpr int binary_format::infinite_power() { - return 0x7FF; -} - -template <> -constexpr int binary_format::sign_index() { - return 63; -} - -bool is_integer(char c) noexcept { return (c >= '0' && c <= '9'); } - -// This should always succeed since it follows a call to parse_number. 
-decimal parse_decimal(const char *&p) noexcept { - decimal answer; - answer.num_digits = 0; - answer.decimal_point = 0; - answer.truncated = false; - answer.negative = (*p == '-'); - if ((*p == '-') || (*p == '+')) { - ++p; - } - - while (*p == '0') { - ++p; - } - while (is_integer(*p)) { - if (answer.num_digits < max_digits) { - answer.digits[answer.num_digits] = uint8_t(*p - '0'); - } - answer.num_digits++; - ++p; - } - if (*p == '.') { - ++p; - const char *first_after_period = p; - // if we have not yet encountered a zero, we have to skip it as well - if (answer.num_digits == 0) { - // skip zeros - while (*p == '0') { - ++p; - } - } - while (is_integer(*p)) { - if (answer.num_digits < max_digits) { - answer.digits[answer.num_digits] = uint8_t(*p - '0'); - } - answer.num_digits++; - ++p; - } - answer.decimal_point = int32_t(first_after_period - p); - } - if (answer.num_digits > 0) { - const char *preverse = p - 1; - int32_t trailing_zeros = 0; - while ((*preverse == '0') || (*preverse == '.')) { - if (*preverse == '0') { - trailing_zeros++; - }; - --preverse; - } - answer.decimal_point += int32_t(answer.num_digits); - answer.num_digits -= uint32_t(trailing_zeros); - } - if (answer.num_digits > max_digits) { - answer.num_digits = max_digits; - answer.truncated = true; - } - if (('e' == *p) || ('E' == *p)) { - ++p; - bool neg_exp = false; - if ('-' == *p) { - neg_exp = true; - ++p; - } else if ('+' == *p) { - ++p; - } - int32_t exp_number = 0; // exponential part - while (is_integer(*p)) { - uint8_t digit = uint8_t(*p - '0'); - if (exp_number < 0x10000) { - exp_number = 10 * exp_number + digit; - } - ++p; - } - answer.decimal_point += (neg_exp ? -exp_number : exp_number); - } - return answer; -} - -// This should always succeed since it follows a call to parse_number. -// Will not read at or beyond the "end" pointer. 
-decimal parse_decimal(const char *&p, const char *end) noexcept { - decimal answer; - answer.num_digits = 0; - answer.decimal_point = 0; - answer.truncated = false; - if (p == end) { - return answer; - } // should never happen - answer.negative = (*p == '-'); - if ((*p == '-') || (*p == '+')) { - ++p; - } - - while ((p != end) && (*p == '0')) { - ++p; - } - while ((p != end) && is_integer(*p)) { - if (answer.num_digits < max_digits) { - answer.digits[answer.num_digits] = uint8_t(*p - '0'); - } - answer.num_digits++; - ++p; - } - if ((p != end) && (*p == '.')) { - ++p; - if (p == end) { - return answer; - } // should never happen - const char *first_after_period = p; - // if we have not yet encountered a zero, we have to skip it as well - if (answer.num_digits == 0) { - // skip zeros - while (*p == '0') { - ++p; - } - } - while ((p != end) && is_integer(*p)) { - if (answer.num_digits < max_digits) { - answer.digits[answer.num_digits] = uint8_t(*p - '0'); - } - answer.num_digits++; - ++p; - } - answer.decimal_point = int32_t(first_after_period - p); - } - if (answer.num_digits > 0) { - const char *preverse = p - 1; - int32_t trailing_zeros = 0; - while ((*preverse == '0') || (*preverse == '.')) { - if (*preverse == '0') { - trailing_zeros++; - }; - --preverse; - } - answer.decimal_point += int32_t(answer.num_digits); - answer.num_digits -= uint32_t(trailing_zeros); - } - if (answer.num_digits > max_digits) { - answer.num_digits = max_digits; - answer.truncated = true; - } - if ((p != end) && (('e' == *p) || ('E' == *p))) { - ++p; - if (p == end) { - return answer; - } // should never happen - bool neg_exp = false; - if ('-' == *p) { - neg_exp = true; - ++p; - } else if ('+' == *p) { - ++p; - } - int32_t exp_number = 0; // exponential part - while ((p != end) && is_integer(*p)) { - uint8_t digit = uint8_t(*p - '0'); - if (exp_number < 0x10000) { - exp_number = 10 * exp_number + digit; - } - ++p; - } - answer.decimal_point += (neg_exp ? 
-exp_number : exp_number); - } - return answer; -} - -namespace { - -// remove all final zeroes -inline void trim(decimal &h) { - while ((h.num_digits > 0) && (h.digits[h.num_digits - 1] == 0)) { - h.num_digits--; - } -} - -uint32_t number_of_digits_decimal_left_shift(decimal &h, uint32_t shift) { - shift &= 63; - const static uint16_t number_of_digits_decimal_left_shift_table[65] = { - 0x0000, 0x0800, 0x0801, 0x0803, 0x1006, 0x1009, 0x100D, 0x1812, 0x1817, - 0x181D, 0x2024, 0x202B, 0x2033, 0x203C, 0x2846, 0x2850, 0x285B, 0x3067, - 0x3073, 0x3080, 0x388E, 0x389C, 0x38AB, 0x38BB, 0x40CC, 0x40DD, 0x40EF, - 0x4902, 0x4915, 0x4929, 0x513E, 0x5153, 0x5169, 0x5180, 0x5998, 0x59B0, - 0x59C9, 0x61E3, 0x61FD, 0x6218, 0x6A34, 0x6A50, 0x6A6D, 0x6A8B, 0x72AA, - 0x72C9, 0x72E9, 0x7B0A, 0x7B2B, 0x7B4D, 0x8370, 0x8393, 0x83B7, 0x83DC, - 0x8C02, 0x8C28, 0x8C4F, 0x9477, 0x949F, 0x94C8, 0x9CF2, 0x051C, 0x051C, - 0x051C, 0x051C, - }; - uint32_t x_a = number_of_digits_decimal_left_shift_table[shift]; - uint32_t x_b = number_of_digits_decimal_left_shift_table[shift + 1]; - uint32_t num_new_digits = x_a >> 11; - uint32_t pow5_a = 0x7FF & x_a; - uint32_t pow5_b = 0x7FF & x_b; - const static uint8_t - number_of_digits_decimal_left_shift_table_powers_of_5[0x051C] = { - 5, 2, 5, 1, 2, 5, 6, 2, 5, 3, 1, 2, 5, 1, 5, 6, 2, 5, 7, 8, 1, 2, 5, - 3, 9, 0, 6, 2, 5, 1, 9, 5, 3, 1, 2, 5, 9, 7, 6, 5, 6, 2, 5, 4, 8, 8, - 2, 8, 1, 2, 5, 2, 4, 4, 1, 4, 0, 6, 2, 5, 1, 2, 2, 0, 7, 0, 3, 1, 2, - 5, 6, 1, 0, 3, 5, 1, 5, 6, 2, 5, 3, 0, 5, 1, 7, 5, 7, 8, 1, 2, 5, 1, - 5, 2, 5, 8, 7, 8, 9, 0, 6, 2, 5, 7, 6, 2, 9, 3, 9, 4, 5, 3, 1, 2, 5, - 3, 8, 1, 4, 6, 9, 7, 2, 6, 5, 6, 2, 5, 1, 9, 0, 7, 3, 4, 8, 6, 3, 2, - 8, 1, 2, 5, 9, 5, 3, 6, 7, 4, 3, 1, 6, 4, 0, 6, 2, 5, 4, 7, 6, 8, 3, - 7, 1, 5, 8, 2, 0, 3, 1, 2, 5, 2, 3, 8, 4, 1, 8, 5, 7, 9, 1, 0, 1, 5, - 6, 2, 5, 1, 1, 9, 2, 0, 9, 2, 8, 9, 5, 5, 0, 7, 8, 1, 2, 5, 5, 9, 6, - 0, 4, 6, 4, 4, 7, 7, 5, 3, 9, 0, 6, 2, 5, 2, 9, 8, 0, 2, 3, 2, 2, 3, - 8, 7, 6, 9, 5, 3, 1, 2, 
5, 1, 4, 9, 0, 1, 1, 6, 1, 1, 9, 3, 8, 4, 7, - 6, 5, 6, 2, 5, 7, 4, 5, 0, 5, 8, 0, 5, 9, 6, 9, 2, 3, 8, 2, 8, 1, 2, - 5, 3, 7, 2, 5, 2, 9, 0, 2, 9, 8, 4, 6, 1, 9, 1, 4, 0, 6, 2, 5, 1, 8, - 6, 2, 6, 4, 5, 1, 4, 9, 2, 3, 0, 9, 5, 7, 0, 3, 1, 2, 5, 9, 3, 1, 3, - 2, 2, 5, 7, 4, 6, 1, 5, 4, 7, 8, 5, 1, 5, 6, 2, 5, 4, 6, 5, 6, 6, 1, - 2, 8, 7, 3, 0, 7, 7, 3, 9, 2, 5, 7, 8, 1, 2, 5, 2, 3, 2, 8, 3, 0, 6, - 4, 3, 6, 5, 3, 8, 6, 9, 6, 2, 8, 9, 0, 6, 2, 5, 1, 1, 6, 4, 1, 5, 3, - 2, 1, 8, 2, 6, 9, 3, 4, 8, 1, 4, 4, 5, 3, 1, 2, 5, 5, 8, 2, 0, 7, 6, - 6, 0, 9, 1, 3, 4, 6, 7, 4, 0, 7, 2, 2, 6, 5, 6, 2, 5, 2, 9, 1, 0, 3, - 8, 3, 0, 4, 5, 6, 7, 3, 3, 7, 0, 3, 6, 1, 3, 2, 8, 1, 2, 5, 1, 4, 5, - 5, 1, 9, 1, 5, 2, 2, 8, 3, 6, 6, 8, 5, 1, 8, 0, 6, 6, 4, 0, 6, 2, 5, - 7, 2, 7, 5, 9, 5, 7, 6, 1, 4, 1, 8, 3, 4, 2, 5, 9, 0, 3, 3, 2, 0, 3, - 1, 2, 5, 3, 6, 3, 7, 9, 7, 8, 8, 0, 7, 0, 9, 1, 7, 1, 2, 9, 5, 1, 6, - 6, 0, 1, 5, 6, 2, 5, 1, 8, 1, 8, 9, 8, 9, 4, 0, 3, 5, 4, 5, 8, 5, 6, - 4, 7, 5, 8, 3, 0, 0, 7, 8, 1, 2, 5, 9, 0, 9, 4, 9, 4, 7, 0, 1, 7, 7, - 2, 9, 2, 8, 2, 3, 7, 9, 1, 5, 0, 3, 9, 0, 6, 2, 5, 4, 5, 4, 7, 4, 7, - 3, 5, 0, 8, 8, 6, 4, 6, 4, 1, 1, 8, 9, 5, 7, 5, 1, 9, 5, 3, 1, 2, 5, - 2, 2, 7, 3, 7, 3, 6, 7, 5, 4, 4, 3, 2, 3, 2, 0, 5, 9, 4, 7, 8, 7, 5, - 9, 7, 6, 5, 6, 2, 5, 1, 1, 3, 6, 8, 6, 8, 3, 7, 7, 2, 1, 6, 1, 6, 0, - 2, 9, 7, 3, 9, 3, 7, 9, 8, 8, 2, 8, 1, 2, 5, 5, 6, 8, 4, 3, 4, 1, 8, - 8, 6, 0, 8, 0, 8, 0, 1, 4, 8, 6, 9, 6, 8, 9, 9, 4, 1, 4, 0, 6, 2, 5, - 2, 8, 4, 2, 1, 7, 0, 9, 4, 3, 0, 4, 0, 4, 0, 0, 7, 4, 3, 4, 8, 4, 4, - 9, 7, 0, 7, 0, 3, 1, 2, 5, 1, 4, 2, 1, 0, 8, 5, 4, 7, 1, 5, 2, 0, 2, - 0, 0, 3, 7, 1, 7, 4, 2, 2, 4, 8, 5, 3, 5, 1, 5, 6, 2, 5, 7, 1, 0, 5, - 4, 2, 7, 3, 5, 7, 6, 0, 1, 0, 0, 1, 8, 5, 8, 7, 1, 1, 2, 4, 2, 6, 7, - 5, 7, 8, 1, 2, 5, 3, 5, 5, 2, 7, 1, 3, 6, 7, 8, 8, 0, 0, 5, 0, 0, 9, - 2, 9, 3, 5, 5, 6, 2, 1, 3, 3, 7, 8, 9, 0, 6, 2, 5, 1, 7, 7, 6, 3, 5, - 6, 8, 3, 9, 4, 0, 0, 2, 5, 0, 4, 6, 4, 6, 7, 7, 8, 1, 0, 6, 6, 8, 9, - 4, 5, 3, 1, 2, 5, 8, 8, 8, 1, 7, 8, 
4, 1, 9, 7, 0, 0, 1, 2, 5, 2, 3, - 2, 3, 3, 8, 9, 0, 5, 3, 3, 4, 4, 7, 2, 6, 5, 6, 2, 5, 4, 4, 4, 0, 8, - 9, 2, 0, 9, 8, 5, 0, 0, 6, 2, 6, 1, 6, 1, 6, 9, 4, 5, 2, 6, 6, 7, 2, - 3, 6, 3, 2, 8, 1, 2, 5, 2, 2, 2, 0, 4, 4, 6, 0, 4, 9, 2, 5, 0, 3, 1, - 3, 0, 8, 0, 8, 4, 7, 2, 6, 3, 3, 3, 6, 1, 8, 1, 6, 4, 0, 6, 2, 5, 1, - 1, 1, 0, 2, 2, 3, 0, 2, 4, 6, 2, 5, 1, 5, 6, 5, 4, 0, 4, 2, 3, 6, 3, - 1, 6, 6, 8, 0, 9, 0, 8, 2, 0, 3, 1, 2, 5, 5, 5, 5, 1, 1, 1, 5, 1, 2, - 3, 1, 2, 5, 7, 8, 2, 7, 0, 2, 1, 1, 8, 1, 5, 8, 3, 4, 0, 4, 5, 4, 1, - 0, 1, 5, 6, 2, 5, 2, 7, 7, 5, 5, 5, 7, 5, 6, 1, 5, 6, 2, 8, 9, 1, 3, - 5, 1, 0, 5, 9, 0, 7, 9, 1, 7, 0, 2, 2, 7, 0, 5, 0, 7, 8, 1, 2, 5, 1, - 3, 8, 7, 7, 7, 8, 7, 8, 0, 7, 8, 1, 4, 4, 5, 6, 7, 5, 5, 2, 9, 5, 3, - 9, 5, 8, 5, 1, 1, 3, 5, 2, 5, 3, 9, 0, 6, 2, 5, 6, 9, 3, 8, 8, 9, 3, - 9, 0, 3, 9, 0, 7, 2, 2, 8, 3, 7, 7, 6, 4, 7, 6, 9, 7, 9, 2, 5, 5, 6, - 7, 6, 2, 6, 9, 5, 3, 1, 2, 5, 3, 4, 6, 9, 4, 4, 6, 9, 5, 1, 9, 5, 3, - 6, 1, 4, 1, 8, 8, 8, 2, 3, 8, 4, 8, 9, 6, 2, 7, 8, 3, 8, 1, 3, 4, 7, - 6, 5, 6, 2, 5, 1, 7, 3, 4, 7, 2, 3, 4, 7, 5, 9, 7, 6, 8, 0, 7, 0, 9, - 4, 4, 1, 1, 9, 2, 4, 4, 8, 1, 3, 9, 1, 9, 0, 6, 7, 3, 8, 2, 8, 1, 2, - 5, 8, 6, 7, 3, 6, 1, 7, 3, 7, 9, 8, 8, 4, 0, 3, 5, 4, 7, 2, 0, 5, 9, - 6, 2, 2, 4, 0, 6, 9, 5, 9, 5, 3, 3, 6, 9, 1, 4, 0, 6, 2, 5, - }; - const uint8_t *pow5 = - &number_of_digits_decimal_left_shift_table_powers_of_5[pow5_a]; - uint32_t i = 0; - uint32_t n = pow5_b - pow5_a; - for (; i < n; i++) { - if (i >= h.num_digits) { - return num_new_digits - 1; - } else if (h.digits[i] == pow5[i]) { - continue; - } else if (h.digits[i] < pow5[i]) { - return num_new_digits - 1; - } else { - return num_new_digits; - } - } - return num_new_digits; -} - -} // end of anonymous namespace - -uint64_t round(decimal &h) { - if ((h.num_digits == 0) || (h.decimal_point < 0)) { - return 0; - } else if (h.decimal_point > 18) { - return UINT64_MAX; - } - // at this point, we know that h.decimal_point >= 0 - uint32_t dp = 
uint32_t(h.decimal_point); - uint64_t n = 0; - for (uint32_t i = 0; i < dp; i++) { - n = (10 * n) + ((i < h.num_digits) ? h.digits[i] : 0); - } - bool round_up = false; - if (dp < h.num_digits) { - round_up = h.digits[dp] >= 5; // normally, we round up - // but we may need to round to even! - if ((h.digits[dp] == 5) && (dp + 1 == h.num_digits)) { - round_up = h.truncated || ((dp > 0) && (1 & h.digits[dp - 1])); - } - } - if (round_up) { - n++; - } - return n; -} - -// computes h * 2^-shift -void decimal_left_shift(decimal &h, uint32_t shift) { - if (h.num_digits == 0) { - return; - } - uint32_t num_new_digits = number_of_digits_decimal_left_shift(h, shift); - int32_t read_index = int32_t(h.num_digits - 1); - uint32_t write_index = h.num_digits - 1 + num_new_digits; - uint64_t n = 0; - - while (read_index >= 0) { - n += uint64_t(h.digits[read_index]) << shift; - uint64_t quotient = n / 10; - uint64_t remainder = n - (10 * quotient); - if (write_index < max_digits) { - h.digits[write_index] = uint8_t(remainder); - } else if (remainder > 0) { - h.truncated = true; - } - n = quotient; - write_index--; - read_index--; - } - while (n > 0) { - uint64_t quotient = n / 10; - uint64_t remainder = n - (10 * quotient); - if (write_index < max_digits) { - h.digits[write_index] = uint8_t(remainder); - } else if (remainder > 0) { - h.truncated = true; - } - n = quotient; - write_index--; - } - h.num_digits += num_new_digits; - if (h.num_digits > max_digits) { - h.num_digits = max_digits; - } - h.decimal_point += int32_t(num_new_digits); - trim(h); -} - -// computes h * 2^shift -void decimal_right_shift(decimal &h, uint32_t shift) { - uint32_t read_index = 0; - uint32_t write_index = 0; - - uint64_t n = 0; - - while ((n >> shift) == 0) { - if (read_index < h.num_digits) { - n = (10 * n) + h.digits[read_index++]; - } else if (n == 0) { - return; - } else { - while ((n >> shift) == 0) { - n = 10 * n; - read_index++; - } - break; - } - } - h.decimal_point -= int32_t(read_index - 1); 
- if (h.decimal_point < -decimal_point_range) { // it is zero - h.num_digits = 0; - h.decimal_point = 0; - h.negative = false; - h.truncated = false; - return; - } - uint64_t mask = (uint64_t(1) << shift) - 1; - while (read_index < h.num_digits) { - uint8_t new_digit = uint8_t(n >> shift); - n = (10 * (n & mask)) + h.digits[read_index++]; - h.digits[write_index++] = new_digit; - } - while (n > 0) { - uint8_t new_digit = uint8_t(n >> shift); - n = 10 * (n & mask); - if (write_index < max_digits) { - h.digits[write_index++] = new_digit; - } else if (new_digit > 0) { - h.truncated = true; - } - } - h.num_digits = write_index; - trim(h); -} - -template -adjusted_mantissa compute_float(decimal &d) { - adjusted_mantissa answer; - if (d.num_digits == 0) { - // should be zero - answer.power2 = 0; - answer.mantissa = 0; - return answer; - } - // At this point, going further, we can assume that d.num_digits > 0. - // We want to guard against excessive decimal point values because - // they can result in long running times. Indeed, we do - // shifts by at most 60 bits. We have that log(10**400)/log(2**60) ~= 22 - // which is fine, but log(10**299995)/log(2**60) ~= 16609 which is not - // fine (runs for a long time). - // - if (d.decimal_point < -324) { - // We have something smaller than 1e-324 which is always zero - // in binary64 and binary32. - // It should be zero. - answer.power2 = 0; - answer.mantissa = 0; - return answer; - } else if (d.decimal_point >= 310) { - // We have something at least as large as 0.1e310 which is - // always infinite. - answer.power2 = binary::infinite_power(); - answer.mantissa = 0; - return answer; - } - - static const uint32_t max_shift = 60; - static const uint32_t num_powers = 19; - static const uint8_t powers[19] = { - 0, 3, 6, 9, 13, 16, 19, 23, 26, 29, // - 33, 36, 39, 43, 46, 49, 53, 56, 59, // - }; - int32_t exp2 = 0; - while (d.decimal_point > 0) { - uint32_t n = uint32_t(d.decimal_point); - uint32_t shift = (n < num_powers) ? 
powers[n] : max_shift; - decimal_right_shift(d, shift); - if (d.decimal_point < -decimal_point_range) { - // should be zero - answer.power2 = 0; - answer.mantissa = 0; - return answer; - } - exp2 += int32_t(shift); - } - // We shift left toward [1/2 ... 1]. - while (d.decimal_point <= 0) { - uint32_t shift; - if (d.decimal_point == 0) { - if (d.digits[0] >= 5) { - break; - } - shift = (d.digits[0] < 2) ? 2 : 1; - } else { - uint32_t n = uint32_t(-d.decimal_point); - shift = (n < num_powers) ? powers[n] : max_shift; - } - decimal_left_shift(d, shift); - if (d.decimal_point > decimal_point_range) { - // we want to get infinity: - answer.power2 = 0xFF; - answer.mantissa = 0; - return answer; - } - exp2 -= int32_t(shift); - } - // We are now in the range [1/2 ... 1] but the binary format uses [1 ... 2]. - exp2--; - constexpr int32_t minimum_exponent = binary::minimum_exponent(); - while ((minimum_exponent + 1) > exp2) { - uint32_t n = uint32_t((minimum_exponent + 1) - exp2); - if (n > max_shift) { - n = max_shift; - } - decimal_right_shift(d, n); - exp2 += int32_t(n); - } - if ((exp2 - minimum_exponent) >= binary::infinite_power()) { - answer.power2 = binary::infinite_power(); - answer.mantissa = 0; - return answer; - } - - const int mantissa_size_in_bits = binary::mantissa_explicit_bits() + 1; - decimal_left_shift(d, mantissa_size_in_bits); - - uint64_t mantissa = round(d); - // It is possible that we have an overflow, in which case we need - // to shift back. 
- if (mantissa >= (uint64_t(1) << mantissa_size_in_bits)) { - decimal_right_shift(d, 1); - exp2 += 1; - mantissa = round(d); - if ((exp2 - minimum_exponent) >= binary::infinite_power()) { - answer.power2 = binary::infinite_power(); - answer.mantissa = 0; - return answer; - } - } - answer.power2 = exp2 - binary::minimum_exponent(); - if (mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) { - answer.power2--; - } - answer.mantissa = - mantissa & ((uint64_t(1) << binary::mantissa_explicit_bits()) - 1); - return answer; -} - -template -adjusted_mantissa parse_long_mantissa(const char *first) { - decimal d = parse_decimal(first); - return compute_float(d); -} - -template -adjusted_mantissa parse_long_mantissa(const char *first, const char *end) { - decimal d = parse_decimal(first, end); - return compute_float(d); -} - -double from_chars(const char *first) noexcept { - bool negative = first[0] == '-'; - if (negative) { - first++; - } - adjusted_mantissa am = parse_long_mantissa>(first); - uint64_t word = am.mantissa; - word |= uint64_t(am.power2) - << binary_format::mantissa_explicit_bits(); - word = negative - ? word | (uint64_t(1) << binary_format::sign_index()) - : word; - double value; - std::memcpy(&value, &word, sizeof(double)); - return value; -} - - -double from_chars(const char *first, const char *end) noexcept { - bool negative = first[0] == '-'; - if (negative) { - first++; - } - adjusted_mantissa am = - parse_long_mantissa>(first, end); - uint64_t word = am.mantissa; - word |= uint64_t(am.power2) - << binary_format::mantissa_explicit_bits(); - word = negative - ? 
word | (uint64_t(1) << binary_format::sign_index()) - : word; - double value; - std::memcpy(&value, &word, sizeof(double)); - return value; -} - -} // internal -} // simdjson -/* end file src/from_chars.cpp */ -/* begin file src/internal/error_tables.cpp */ - -namespace simdjson { -namespace internal { - -SIMDJSON_DLLIMPORTEXPORT const error_code_info error_codes[]{ - {SUCCESS, "No error"}, - {CAPACITY, "This parser can't support a document that big"}, - {MEMALLOC, "Error allocating memory, we're most likely out of memory"}, - {TAPE_ERROR, - "The JSON document has an improper structure: missing or superfluous " - "commas, braces, missing keys, etc."}, - {DEPTH_ERROR, - "The JSON document was too deep (too many nested objects and arrays)"}, - {STRING_ERROR, "Problem while parsing a string"}, - {T_ATOM_ERROR, - "Problem while parsing an atom starting with the letter 't'"}, - {F_ATOM_ERROR, - "Problem while parsing an atom starting with the letter 'f'"}, - {N_ATOM_ERROR, - "Problem while parsing an atom starting with the letter 'n'"}, - {NUMBER_ERROR, "Problem while parsing a number"}, - {UTF8_ERROR, "The input is not valid UTF-8"}, - {UNINITIALIZED, "Uninitialized"}, - {EMPTY, "Empty: no JSON found"}, - {UNESCAPED_CHARS, - "Within strings, some characters must be escaped, we found unescaped " - "characters"}, - {UNCLOSED_STRING, "A string is opened, but never closed."}, - {UNSUPPORTED_ARCHITECTURE, - "simdjson does not have an implementation supported by this CPU " - "architecture (perhaps it's a non-SIMD CPU?)."}, - {INCORRECT_TYPE, "The JSON element does not have the requested type."}, - {NUMBER_OUT_OF_RANGE, - "The JSON number is too large or too small to fit within the requested " - "type."}, - {INDEX_OUT_OF_BOUNDS, - "Attempted to access an element of a JSON array that is beyond its " - "length."}, - {NO_SUCH_FIELD, "The JSON field referenced does not exist in this object."}, - {IO_ERROR, "Error reading the file."}, - {INVALID_JSON_POINTER, "Invalid JSON pointer 
syntax."}, - {INVALID_URI_FRAGMENT, "Invalid URI fragment syntax."}, - {UNEXPECTED_ERROR, - "Unexpected error, consider reporting this problem as you may have found " - "a bug in simdjson"}, - {PARSER_IN_USE, - "Cannot parse a new document while a document is still in use."}, - {OUT_OF_ORDER_ITERATION, - "Objects and arrays can only be iterated when they are first " - "encountered."}, - {INSUFFICIENT_PADDING, - "simdjson requires the input JSON string to have at least " - "SIMDJSON_PADDING extra bytes allocated, beyond the string's length. " - "Consider using the simdjson::padded_string class if needed."}, - {INCOMPLETE_ARRAY_OR_OBJECT, - "JSON document ended early in the middle of an object or array."}, - {SCALAR_DOCUMENT_AS_VALUE, - "A JSON document made of a scalar (number, Boolean, null or string) is " - "treated as a value. Use get_bool(), get_double(), etc. on the document " - "instead. "}, - {OUT_OF_BOUNDS, - "Attempted to access location outside of document."}}; // error_messages[] - -} // namespace internal -} // namespace simdjson -/* end file src/internal/error_tables.cpp */ -/* begin file src/internal/jsoncharutils_tables.cpp */ - -namespace simdjson { -namespace internal { - -// structural chars here are -// they are { 0x7b } 0x7d : 0x3a [ 0x5b ] 0x5d , 0x2c (and NULL) -// we are also interested in the four whitespace characters -// space 0x20, linefeed 0x0a, horizontal tab 0x09 and carriage return 0x0d - -SIMDJSON_DLLIMPORTEXPORT const bool structural_or_whitespace_negated[256] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, - - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, - - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - -SIMDJSON_DLLIMPORTEXPORT const bool structural_or_whitespace[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - -SIMDJSON_DLLIMPORTEXPORT const uint32_t digit_to_val32[886] = { - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, - 0x6, 0x7, 0x8, 0x9, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xa, - 0xb, 0xc, 
0xd, 0xe, 0xf, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xa, 0xb, 0xc, 0xd, 0xe, - 0xf, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x0, 0x10, 0x20, 0x30, 0x40, 0x50, - 0x60, 0x70, 0x80, 0x90, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xa0, - 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, - 0xf0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 
- 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x0, 0x100, 0x200, 0x300, 0x400, 0x500, - 0x600, 0x700, 0x800, 0x900, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xa00, - 0xb00, 0xc00, 0xd00, 0xe00, 0xf00, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xa00, 0xb00, 0xc00, 0xd00, 0xe00, - 0xf00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x0, 0x1000, 0x2000, 0x3000, 0x4000, 0x5000, - 0x6000, 0x7000, 0x8000, 0x9000, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xa000, - 0xb000, 0xc000, 0xd000, 0xe000, 0xf000, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xa000, 0xb000, 0xc000, 0xd000, 0xe000, - 0xf000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 
0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF}; - -} // namespace internal -} // namespace simdjson -/* end file src/internal/jsoncharutils_tables.cpp */ -/* begin file src/internal/numberparsing_tables.cpp */ - -namespace simdjson { -namespace internal { - -// Precomputed powers of ten from 10^0 to 10^22. 
These -// can be represented exactly using the double type. -SIMDJSON_DLLIMPORTEXPORT const double power_of_ten[] = { - 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, - 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22}; - -/** - * When mapping numbers from decimal to binary, - * we go from w * 10^q to m * 2^p but we have - * 10^q = 5^q * 2^q, so effectively - * we are trying to match - * w * 2^q * 5^q to m * 2^p. Thus the powers of two - * are not a concern since they can be represented - * exactly using the binary notation, only the powers of five - * affect the binary significand. - */ - - -// The truncated powers of five from 5^-342 all the way to 5^308 -// The mantissa is truncated to 128 bits, and -// never rounded up. Uses about 10KB. -SIMDJSON_DLLIMPORTEXPORT const uint64_t power_of_five_128[] = { - 0xeef453d6923bd65a, 0x113faa2906a13b3f, - 0x9558b4661b6565f8, 0x4ac7ca59a424c507, - 0xbaaee17fa23ebf76, 0x5d79bcf00d2df649, - 0xe95a99df8ace6f53, 0xf4d82c2c107973dc, - 0x91d8a02bb6c10594, 0x79071b9b8a4be869, - 0xb64ec836a47146f9, 0x9748e2826cdee284, - 0xe3e27a444d8d98b7, 0xfd1b1b2308169b25, - 0x8e6d8c6ab0787f72, 0xfe30f0f5e50e20f7, - 0xb208ef855c969f4f, 0xbdbd2d335e51a935, - 0xde8b2b66b3bc4723, 0xad2c788035e61382, - 0x8b16fb203055ac76, 0x4c3bcb5021afcc31, - 0xaddcb9e83c6b1793, 0xdf4abe242a1bbf3d, - 0xd953e8624b85dd78, 0xd71d6dad34a2af0d, - 0x87d4713d6f33aa6b, 0x8672648c40e5ad68, - 0xa9c98d8ccb009506, 0x680efdaf511f18c2, - 0xd43bf0effdc0ba48, 0x212bd1b2566def2, - 0x84a57695fe98746d, 0x14bb630f7604b57, - 0xa5ced43b7e3e9188, 0x419ea3bd35385e2d, - 0xcf42894a5dce35ea, 0x52064cac828675b9, - 0x818995ce7aa0e1b2, 0x7343efebd1940993, - 0xa1ebfb4219491a1f, 0x1014ebe6c5f90bf8, - 0xca66fa129f9b60a6, 0xd41a26e077774ef6, - 0xfd00b897478238d0, 0x8920b098955522b4, - 0x9e20735e8cb16382, 0x55b46e5f5d5535b0, - 0xc5a890362fddbc62, 0xeb2189f734aa831d, - 0xf712b443bbd52b7b, 0xa5e9ec7501d523e4, - 0x9a6bb0aa55653b2d, 0x47b233c92125366e, - 
0xc1069cd4eabe89f8, 0x999ec0bb696e840a, - 0xf148440a256e2c76, 0xc00670ea43ca250d, - 0x96cd2a865764dbca, 0x380406926a5e5728, - 0xbc807527ed3e12bc, 0xc605083704f5ecf2, - 0xeba09271e88d976b, 0xf7864a44c633682e, - 0x93445b8731587ea3, 0x7ab3ee6afbe0211d, - 0xb8157268fdae9e4c, 0x5960ea05bad82964, - 0xe61acf033d1a45df, 0x6fb92487298e33bd, - 0x8fd0c16206306bab, 0xa5d3b6d479f8e056, - 0xb3c4f1ba87bc8696, 0x8f48a4899877186c, - 0xe0b62e2929aba83c, 0x331acdabfe94de87, - 0x8c71dcd9ba0b4925, 0x9ff0c08b7f1d0b14, - 0xaf8e5410288e1b6f, 0x7ecf0ae5ee44dd9, - 0xdb71e91432b1a24a, 0xc9e82cd9f69d6150, - 0x892731ac9faf056e, 0xbe311c083a225cd2, - 0xab70fe17c79ac6ca, 0x6dbd630a48aaf406, - 0xd64d3d9db981787d, 0x92cbbccdad5b108, - 0x85f0468293f0eb4e, 0x25bbf56008c58ea5, - 0xa76c582338ed2621, 0xaf2af2b80af6f24e, - 0xd1476e2c07286faa, 0x1af5af660db4aee1, - 0x82cca4db847945ca, 0x50d98d9fc890ed4d, - 0xa37fce126597973c, 0xe50ff107bab528a0, - 0xcc5fc196fefd7d0c, 0x1e53ed49a96272c8, - 0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7a, - 0x9faacf3df73609b1, 0x77b191618c54e9ac, - 0xc795830d75038c1d, 0xd59df5b9ef6a2417, - 0xf97ae3d0d2446f25, 0x4b0573286b44ad1d, - 0x9becce62836ac577, 0x4ee367f9430aec32, - 0xc2e801fb244576d5, 0x229c41f793cda73f, - 0xf3a20279ed56d48a, 0x6b43527578c1110f, - 0x9845418c345644d6, 0x830a13896b78aaa9, - 0xbe5691ef416bd60c, 0x23cc986bc656d553, - 0xedec366b11c6cb8f, 0x2cbfbe86b7ec8aa8, - 0x94b3a202eb1c3f39, 0x7bf7d71432f3d6a9, - 0xb9e08a83a5e34f07, 0xdaf5ccd93fb0cc53, - 0xe858ad248f5c22c9, 0xd1b3400f8f9cff68, - 0x91376c36d99995be, 0x23100809b9c21fa1, - 0xb58547448ffffb2d, 0xabd40a0c2832a78a, - 0xe2e69915b3fff9f9, 0x16c90c8f323f516c, - 0x8dd01fad907ffc3b, 0xae3da7d97f6792e3, - 0xb1442798f49ffb4a, 0x99cd11cfdf41779c, - 0xdd95317f31c7fa1d, 0x40405643d711d583, - 0x8a7d3eef7f1cfc52, 0x482835ea666b2572, - 0xad1c8eab5ee43b66, 0xda3243650005eecf, - 0xd863b256369d4a40, 0x90bed43e40076a82, - 0x873e4f75e2224e68, 0x5a7744a6e804a291, - 0xa90de3535aaae202, 0x711515d0a205cb36, - 0xd3515c2831559a83, 
0xd5a5b44ca873e03, - 0x8412d9991ed58091, 0xe858790afe9486c2, - 0xa5178fff668ae0b6, 0x626e974dbe39a872, - 0xce5d73ff402d98e3, 0xfb0a3d212dc8128f, - 0x80fa687f881c7f8e, 0x7ce66634bc9d0b99, - 0xa139029f6a239f72, 0x1c1fffc1ebc44e80, - 0xc987434744ac874e, 0xa327ffb266b56220, - 0xfbe9141915d7a922, 0x4bf1ff9f0062baa8, - 0x9d71ac8fada6c9b5, 0x6f773fc3603db4a9, - 0xc4ce17b399107c22, 0xcb550fb4384d21d3, - 0xf6019da07f549b2b, 0x7e2a53a146606a48, - 0x99c102844f94e0fb, 0x2eda7444cbfc426d, - 0xc0314325637a1939, 0xfa911155fefb5308, - 0xf03d93eebc589f88, 0x793555ab7eba27ca, - 0x96267c7535b763b5, 0x4bc1558b2f3458de, - 0xbbb01b9283253ca2, 0x9eb1aaedfb016f16, - 0xea9c227723ee8bcb, 0x465e15a979c1cadc, - 0x92a1958a7675175f, 0xbfacd89ec191ec9, - 0xb749faed14125d36, 0xcef980ec671f667b, - 0xe51c79a85916f484, 0x82b7e12780e7401a, - 0x8f31cc0937ae58d2, 0xd1b2ecb8b0908810, - 0xb2fe3f0b8599ef07, 0x861fa7e6dcb4aa15, - 0xdfbdcece67006ac9, 0x67a791e093e1d49a, - 0x8bd6a141006042bd, 0xe0c8bb2c5c6d24e0, - 0xaecc49914078536d, 0x58fae9f773886e18, - 0xda7f5bf590966848, 0xaf39a475506a899e, - 0x888f99797a5e012d, 0x6d8406c952429603, - 0xaab37fd7d8f58178, 0xc8e5087ba6d33b83, - 0xd5605fcdcf32e1d6, 0xfb1e4a9a90880a64, - 0x855c3be0a17fcd26, 0x5cf2eea09a55067f, - 0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481e, - 0xd0601d8efc57b08b, 0xf13b94daf124da26, - 0x823c12795db6ce57, 0x76c53d08d6b70858, - 0xa2cb1717b52481ed, 0x54768c4b0c64ca6e, - 0xcb7ddcdda26da268, 0xa9942f5dcf7dfd09, - 0xfe5d54150b090b02, 0xd3f93b35435d7c4c, - 0x9efa548d26e5a6e1, 0xc47bc5014a1a6daf, - 0xc6b8e9b0709f109a, 0x359ab6419ca1091b, - 0xf867241c8cc6d4c0, 0xc30163d203c94b62, - 0x9b407691d7fc44f8, 0x79e0de63425dcf1d, - 0xc21094364dfb5636, 0x985915fc12f542e4, - 0xf294b943e17a2bc4, 0x3e6f5b7b17b2939d, - 0x979cf3ca6cec5b5a, 0xa705992ceecf9c42, - 0xbd8430bd08277231, 0x50c6ff782a838353, - 0xece53cec4a314ebd, 0xa4f8bf5635246428, - 0x940f4613ae5ed136, 0x871b7795e136be99, - 0xb913179899f68584, 0x28e2557b59846e3f, - 0xe757dd7ec07426e5, 0x331aeada2fe589cf, - 
0x9096ea6f3848984f, 0x3ff0d2c85def7621, - 0xb4bca50b065abe63, 0xfed077a756b53a9, - 0xe1ebce4dc7f16dfb, 0xd3e8495912c62894, - 0x8d3360f09cf6e4bd, 0x64712dd7abbbd95c, - 0xb080392cc4349dec, 0xbd8d794d96aacfb3, - 0xdca04777f541c567, 0xecf0d7a0fc5583a0, - 0x89e42caaf9491b60, 0xf41686c49db57244, - 0xac5d37d5b79b6239, 0x311c2875c522ced5, - 0xd77485cb25823ac7, 0x7d633293366b828b, - 0x86a8d39ef77164bc, 0xae5dff9c02033197, - 0xa8530886b54dbdeb, 0xd9f57f830283fdfc, - 0xd267caa862a12d66, 0xd072df63c324fd7b, - 0x8380dea93da4bc60, 0x4247cb9e59f71e6d, - 0xa46116538d0deb78, 0x52d9be85f074e608, - 0xcd795be870516656, 0x67902e276c921f8b, - 0x806bd9714632dff6, 0xba1cd8a3db53b6, - 0xa086cfcd97bf97f3, 0x80e8a40eccd228a4, - 0xc8a883c0fdaf7df0, 0x6122cd128006b2cd, - 0xfad2a4b13d1b5d6c, 0x796b805720085f81, - 0x9cc3a6eec6311a63, 0xcbe3303674053bb0, - 0xc3f490aa77bd60fc, 0xbedbfc4411068a9c, - 0xf4f1b4d515acb93b, 0xee92fb5515482d44, - 0x991711052d8bf3c5, 0x751bdd152d4d1c4a, - 0xbf5cd54678eef0b6, 0xd262d45a78a0635d, - 0xef340a98172aace4, 0x86fb897116c87c34, - 0x9580869f0e7aac0e, 0xd45d35e6ae3d4da0, - 0xbae0a846d2195712, 0x8974836059cca109, - 0xe998d258869facd7, 0x2bd1a438703fc94b, - 0x91ff83775423cc06, 0x7b6306a34627ddcf, - 0xb67f6455292cbf08, 0x1a3bc84c17b1d542, - 0xe41f3d6a7377eeca, 0x20caba5f1d9e4a93, - 0x8e938662882af53e, 0x547eb47b7282ee9c, - 0xb23867fb2a35b28d, 0xe99e619a4f23aa43, - 0xdec681f9f4c31f31, 0x6405fa00e2ec94d4, - 0x8b3c113c38f9f37e, 0xde83bc408dd3dd04, - 0xae0b158b4738705e, 0x9624ab50b148d445, - 0xd98ddaee19068c76, 0x3badd624dd9b0957, - 0x87f8a8d4cfa417c9, 0xe54ca5d70a80e5d6, - 0xa9f6d30a038d1dbc, 0x5e9fcf4ccd211f4c, - 0xd47487cc8470652b, 0x7647c3200069671f, - 0x84c8d4dfd2c63f3b, 0x29ecd9f40041e073, - 0xa5fb0a17c777cf09, 0xf468107100525890, - 0xcf79cc9db955c2cc, 0x7182148d4066eeb4, - 0x81ac1fe293d599bf, 0xc6f14cd848405530, - 0xa21727db38cb002f, 0xb8ada00e5a506a7c, - 0xca9cf1d206fdc03b, 0xa6d90811f0e4851c, - 0xfd442e4688bd304a, 0x908f4a166d1da663, - 0x9e4a9cec15763e2e, 
0x9a598e4e043287fe, - 0xc5dd44271ad3cdba, 0x40eff1e1853f29fd, - 0xf7549530e188c128, 0xd12bee59e68ef47c, - 0x9a94dd3e8cf578b9, 0x82bb74f8301958ce, - 0xc13a148e3032d6e7, 0xe36a52363c1faf01, - 0xf18899b1bc3f8ca1, 0xdc44e6c3cb279ac1, - 0x96f5600f15a7b7e5, 0x29ab103a5ef8c0b9, - 0xbcb2b812db11a5de, 0x7415d448f6b6f0e7, - 0xebdf661791d60f56, 0x111b495b3464ad21, - 0x936b9fcebb25c995, 0xcab10dd900beec34, - 0xb84687c269ef3bfb, 0x3d5d514f40eea742, - 0xe65829b3046b0afa, 0xcb4a5a3112a5112, - 0x8ff71a0fe2c2e6dc, 0x47f0e785eaba72ab, - 0xb3f4e093db73a093, 0x59ed216765690f56, - 0xe0f218b8d25088b8, 0x306869c13ec3532c, - 0x8c974f7383725573, 0x1e414218c73a13fb, - 0xafbd2350644eeacf, 0xe5d1929ef90898fa, - 0xdbac6c247d62a583, 0xdf45f746b74abf39, - 0x894bc396ce5da772, 0x6b8bba8c328eb783, - 0xab9eb47c81f5114f, 0x66ea92f3f326564, - 0xd686619ba27255a2, 0xc80a537b0efefebd, - 0x8613fd0145877585, 0xbd06742ce95f5f36, - 0xa798fc4196e952e7, 0x2c48113823b73704, - 0xd17f3b51fca3a7a0, 0xf75a15862ca504c5, - 0x82ef85133de648c4, 0x9a984d73dbe722fb, - 0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebba, - 0xcc963fee10b7d1b3, 0x318df905079926a8, - 0xffbbcfe994e5c61f, 0xfdf17746497f7052, - 0x9fd561f1fd0f9bd3, 0xfeb6ea8bedefa633, - 0xc7caba6e7c5382c8, 0xfe64a52ee96b8fc0, - 0xf9bd690a1b68637b, 0x3dfdce7aa3c673b0, - 0x9c1661a651213e2d, 0x6bea10ca65c084e, - 0xc31bfa0fe5698db8, 0x486e494fcff30a62, - 0xf3e2f893dec3f126, 0x5a89dba3c3efccfa, - 0x986ddb5c6b3a76b7, 0xf89629465a75e01c, - 0xbe89523386091465, 0xf6bbb397f1135823, - 0xee2ba6c0678b597f, 0x746aa07ded582e2c, - 0x94db483840b717ef, 0xa8c2a44eb4571cdc, - 0xba121a4650e4ddeb, 0x92f34d62616ce413, - 0xe896a0d7e51e1566, 0x77b020baf9c81d17, - 0x915e2486ef32cd60, 0xace1474dc1d122e, - 0xb5b5ada8aaff80b8, 0xd819992132456ba, - 0xe3231912d5bf60e6, 0x10e1fff697ed6c69, - 0x8df5efabc5979c8f, 0xca8d3ffa1ef463c1, - 0xb1736b96b6fd83b3, 0xbd308ff8a6b17cb2, - 0xddd0467c64bce4a0, 0xac7cb3f6d05ddbde, - 0x8aa22c0dbef60ee4, 0x6bcdf07a423aa96b, - 0xad4ab7112eb3929d, 0x86c16c98d2c953c6, - 
0xd89d64d57a607744, 0xe871c7bf077ba8b7, - 0x87625f056c7c4a8b, 0x11471cd764ad4972, - 0xa93af6c6c79b5d2d, 0xd598e40d3dd89bcf, - 0xd389b47879823479, 0x4aff1d108d4ec2c3, - 0x843610cb4bf160cb, 0xcedf722a585139ba, - 0xa54394fe1eedb8fe, 0xc2974eb4ee658828, - 0xce947a3da6a9273e, 0x733d226229feea32, - 0x811ccc668829b887, 0x806357d5a3f525f, - 0xa163ff802a3426a8, 0xca07c2dcb0cf26f7, - 0xc9bcff6034c13052, 0xfc89b393dd02f0b5, - 0xfc2c3f3841f17c67, 0xbbac2078d443ace2, - 0x9d9ba7832936edc0, 0xd54b944b84aa4c0d, - 0xc5029163f384a931, 0xa9e795e65d4df11, - 0xf64335bcf065d37d, 0x4d4617b5ff4a16d5, - 0x99ea0196163fa42e, 0x504bced1bf8e4e45, - 0xc06481fb9bcf8d39, 0xe45ec2862f71e1d6, - 0xf07da27a82c37088, 0x5d767327bb4e5a4c, - 0x964e858c91ba2655, 0x3a6a07f8d510f86f, - 0xbbe226efb628afea, 0x890489f70a55368b, - 0xeadab0aba3b2dbe5, 0x2b45ac74ccea842e, - 0x92c8ae6b464fc96f, 0x3b0b8bc90012929d, - 0xb77ada0617e3bbcb, 0x9ce6ebb40173744, - 0xe55990879ddcaabd, 0xcc420a6a101d0515, - 0x8f57fa54c2a9eab6, 0x9fa946824a12232d, - 0xb32df8e9f3546564, 0x47939822dc96abf9, - 0xdff9772470297ebd, 0x59787e2b93bc56f7, - 0x8bfbea76c619ef36, 0x57eb4edb3c55b65a, - 0xaefae51477a06b03, 0xede622920b6b23f1, - 0xdab99e59958885c4, 0xe95fab368e45eced, - 0x88b402f7fd75539b, 0x11dbcb0218ebb414, - 0xaae103b5fcd2a881, 0xd652bdc29f26a119, - 0xd59944a37c0752a2, 0x4be76d3346f0495f, - 0x857fcae62d8493a5, 0x6f70a4400c562ddb, - 0xa6dfbd9fb8e5b88e, 0xcb4ccd500f6bb952, - 0xd097ad07a71f26b2, 0x7e2000a41346a7a7, - 0x825ecc24c873782f, 0x8ed400668c0c28c8, - 0xa2f67f2dfa90563b, 0x728900802f0f32fa, - 0xcbb41ef979346bca, 0x4f2b40a03ad2ffb9, - 0xfea126b7d78186bc, 0xe2f610c84987bfa8, - 0x9f24b832e6b0f436, 0xdd9ca7d2df4d7c9, - 0xc6ede63fa05d3143, 0x91503d1c79720dbb, - 0xf8a95fcf88747d94, 0x75a44c6397ce912a, - 0x9b69dbe1b548ce7c, 0xc986afbe3ee11aba, - 0xc24452da229b021b, 0xfbe85badce996168, - 0xf2d56790ab41c2a2, 0xfae27299423fb9c3, - 0x97c560ba6b0919a5, 0xdccd879fc967d41a, - 0xbdb6b8e905cb600f, 0x5400e987bbc1c920, - 0xed246723473e3813, 
0x290123e9aab23b68, - 0x9436c0760c86e30b, 0xf9a0b6720aaf6521, - 0xb94470938fa89bce, 0xf808e40e8d5b3e69, - 0xe7958cb87392c2c2, 0xb60b1d1230b20e04, - 0x90bd77f3483bb9b9, 0xb1c6f22b5e6f48c2, - 0xb4ecd5f01a4aa828, 0x1e38aeb6360b1af3, - 0xe2280b6c20dd5232, 0x25c6da63c38de1b0, - 0x8d590723948a535f, 0x579c487e5a38ad0e, - 0xb0af48ec79ace837, 0x2d835a9df0c6d851, - 0xdcdb1b2798182244, 0xf8e431456cf88e65, - 0x8a08f0f8bf0f156b, 0x1b8e9ecb641b58ff, - 0xac8b2d36eed2dac5, 0xe272467e3d222f3f, - 0xd7adf884aa879177, 0x5b0ed81dcc6abb0f, - 0x86ccbb52ea94baea, 0x98e947129fc2b4e9, - 0xa87fea27a539e9a5, 0x3f2398d747b36224, - 0xd29fe4b18e88640e, 0x8eec7f0d19a03aad, - 0x83a3eeeef9153e89, 0x1953cf68300424ac, - 0xa48ceaaab75a8e2b, 0x5fa8c3423c052dd7, - 0xcdb02555653131b6, 0x3792f412cb06794d, - 0x808e17555f3ebf11, 0xe2bbd88bbee40bd0, - 0xa0b19d2ab70e6ed6, 0x5b6aceaeae9d0ec4, - 0xc8de047564d20a8b, 0xf245825a5a445275, - 0xfb158592be068d2e, 0xeed6e2f0f0d56712, - 0x9ced737bb6c4183d, 0x55464dd69685606b, - 0xc428d05aa4751e4c, 0xaa97e14c3c26b886, - 0xf53304714d9265df, 0xd53dd99f4b3066a8, - 0x993fe2c6d07b7fab, 0xe546a8038efe4029, - 0xbf8fdb78849a5f96, 0xde98520472bdd033, - 0xef73d256a5c0f77c, 0x963e66858f6d4440, - 0x95a8637627989aad, 0xdde7001379a44aa8, - 0xbb127c53b17ec159, 0x5560c018580d5d52, - 0xe9d71b689dde71af, 0xaab8f01e6e10b4a6, - 0x9226712162ab070d, 0xcab3961304ca70e8, - 0xb6b00d69bb55c8d1, 0x3d607b97c5fd0d22, - 0xe45c10c42a2b3b05, 0x8cb89a7db77c506a, - 0x8eb98a7a9a5b04e3, 0x77f3608e92adb242, - 0xb267ed1940f1c61c, 0x55f038b237591ed3, - 0xdf01e85f912e37a3, 0x6b6c46dec52f6688, - 0x8b61313bbabce2c6, 0x2323ac4b3b3da015, - 0xae397d8aa96c1b77, 0xabec975e0a0d081a, - 0xd9c7dced53c72255, 0x96e7bd358c904a21, - 0x881cea14545c7575, 0x7e50d64177da2e54, - 0xaa242499697392d2, 0xdde50bd1d5d0b9e9, - 0xd4ad2dbfc3d07787, 0x955e4ec64b44e864, - 0x84ec3c97da624ab4, 0xbd5af13bef0b113e, - 0xa6274bbdd0fadd61, 0xecb1ad8aeacdd58e, - 0xcfb11ead453994ba, 0x67de18eda5814af2, - 0x81ceb32c4b43fcf4, 0x80eacf948770ced7, - 
0xa2425ff75e14fc31, 0xa1258379a94d028d, - 0xcad2f7f5359a3b3e, 0x96ee45813a04330, - 0xfd87b5f28300ca0d, 0x8bca9d6e188853fc, - 0x9e74d1b791e07e48, 0x775ea264cf55347e, - 0xc612062576589dda, 0x95364afe032a81a0, - 0xf79687aed3eec551, 0x3a83ddbd83f52210, - 0x9abe14cd44753b52, 0xc4926a9672793580, - 0xc16d9a0095928a27, 0x75b7053c0f178400, - 0xf1c90080baf72cb1, 0x5324c68b12dd6800, - 0x971da05074da7bee, 0xd3f6fc16ebca8000, - 0xbce5086492111aea, 0x88f4bb1ca6bd0000, - 0xec1e4a7db69561a5, 0x2b31e9e3d0700000, - 0x9392ee8e921d5d07, 0x3aff322e62600000, - 0xb877aa3236a4b449, 0x9befeb9fad487c3, - 0xe69594bec44de15b, 0x4c2ebe687989a9b4, - 0x901d7cf73ab0acd9, 0xf9d37014bf60a11, - 0xb424dc35095cd80f, 0x538484c19ef38c95, - 0xe12e13424bb40e13, 0x2865a5f206b06fba, - 0x8cbccc096f5088cb, 0xf93f87b7442e45d4, - 0xafebff0bcb24aafe, 0xf78f69a51539d749, - 0xdbe6fecebdedd5be, 0xb573440e5a884d1c, - 0x89705f4136b4a597, 0x31680a88f8953031, - 0xabcc77118461cefc, 0xfdc20d2b36ba7c3e, - 0xd6bf94d5e57a42bc, 0x3d32907604691b4d, - 0x8637bd05af6c69b5, 0xa63f9a49c2c1b110, - 0xa7c5ac471b478423, 0xfcf80dc33721d54, - 0xd1b71758e219652b, 0xd3c36113404ea4a9, - 0x83126e978d4fdf3b, 0x645a1cac083126ea, - 0xa3d70a3d70a3d70a, 0x3d70a3d70a3d70a4, - 0xcccccccccccccccc, 0xcccccccccccccccd, - 0x8000000000000000, 0x0, - 0xa000000000000000, 0x0, - 0xc800000000000000, 0x0, - 0xfa00000000000000, 0x0, - 0x9c40000000000000, 0x0, - 0xc350000000000000, 0x0, - 0xf424000000000000, 0x0, - 0x9896800000000000, 0x0, - 0xbebc200000000000, 0x0, - 0xee6b280000000000, 0x0, - 0x9502f90000000000, 0x0, - 0xba43b74000000000, 0x0, - 0xe8d4a51000000000, 0x0, - 0x9184e72a00000000, 0x0, - 0xb5e620f480000000, 0x0, - 0xe35fa931a0000000, 0x0, - 0x8e1bc9bf04000000, 0x0, - 0xb1a2bc2ec5000000, 0x0, - 0xde0b6b3a76400000, 0x0, - 0x8ac7230489e80000, 0x0, - 0xad78ebc5ac620000, 0x0, - 0xd8d726b7177a8000, 0x0, - 0x878678326eac9000, 0x0, - 0xa968163f0a57b400, 0x0, - 0xd3c21bcecceda100, 0x0, - 0x84595161401484a0, 0x0, - 0xa56fa5b99019a5c8, 0x0, - 
0xcecb8f27f4200f3a, 0x0, - 0x813f3978f8940984, 0x4000000000000000, - 0xa18f07d736b90be5, 0x5000000000000000, - 0xc9f2c9cd04674ede, 0xa400000000000000, - 0xfc6f7c4045812296, 0x4d00000000000000, - 0x9dc5ada82b70b59d, 0xf020000000000000, - 0xc5371912364ce305, 0x6c28000000000000, - 0xf684df56c3e01bc6, 0xc732000000000000, - 0x9a130b963a6c115c, 0x3c7f400000000000, - 0xc097ce7bc90715b3, 0x4b9f100000000000, - 0xf0bdc21abb48db20, 0x1e86d40000000000, - 0x96769950b50d88f4, 0x1314448000000000, - 0xbc143fa4e250eb31, 0x17d955a000000000, - 0xeb194f8e1ae525fd, 0x5dcfab0800000000, - 0x92efd1b8d0cf37be, 0x5aa1cae500000000, - 0xb7abc627050305ad, 0xf14a3d9e40000000, - 0xe596b7b0c643c719, 0x6d9ccd05d0000000, - 0x8f7e32ce7bea5c6f, 0xe4820023a2000000, - 0xb35dbf821ae4f38b, 0xdda2802c8a800000, - 0xe0352f62a19e306e, 0xd50b2037ad200000, - 0x8c213d9da502de45, 0x4526f422cc340000, - 0xaf298d050e4395d6, 0x9670b12b7f410000, - 0xdaf3f04651d47b4c, 0x3c0cdd765f114000, - 0x88d8762bf324cd0f, 0xa5880a69fb6ac800, - 0xab0e93b6efee0053, 0x8eea0d047a457a00, - 0xd5d238a4abe98068, 0x72a4904598d6d880, - 0x85a36366eb71f041, 0x47a6da2b7f864750, - 0xa70c3c40a64e6c51, 0x999090b65f67d924, - 0xd0cf4b50cfe20765, 0xfff4b4e3f741cf6d, - 0x82818f1281ed449f, 0xbff8f10e7a8921a4, - 0xa321f2d7226895c7, 0xaff72d52192b6a0d, - 0xcbea6f8ceb02bb39, 0x9bf4f8a69f764490, - 0xfee50b7025c36a08, 0x2f236d04753d5b4, - 0x9f4f2726179a2245, 0x1d762422c946590, - 0xc722f0ef9d80aad6, 0x424d3ad2b7b97ef5, - 0xf8ebad2b84e0d58b, 0xd2e0898765a7deb2, - 0x9b934c3b330c8577, 0x63cc55f49f88eb2f, - 0xc2781f49ffcfa6d5, 0x3cbf6b71c76b25fb, - 0xf316271c7fc3908a, 0x8bef464e3945ef7a, - 0x97edd871cfda3a56, 0x97758bf0e3cbb5ac, - 0xbde94e8e43d0c8ec, 0x3d52eeed1cbea317, - 0xed63a231d4c4fb27, 0x4ca7aaa863ee4bdd, - 0x945e455f24fb1cf8, 0x8fe8caa93e74ef6a, - 0xb975d6b6ee39e436, 0xb3e2fd538e122b44, - 0xe7d34c64a9c85d44, 0x60dbbca87196b616, - 0x90e40fbeea1d3a4a, 0xbc8955e946fe31cd, - 0xb51d13aea4a488dd, 0x6babab6398bdbe41, - 0xe264589a4dcdab14, 0xc696963c7eed2dd1, - 
0x8d7eb76070a08aec, 0xfc1e1de5cf543ca2, - 0xb0de65388cc8ada8, 0x3b25a55f43294bcb, - 0xdd15fe86affad912, 0x49ef0eb713f39ebe, - 0x8a2dbf142dfcc7ab, 0x6e3569326c784337, - 0xacb92ed9397bf996, 0x49c2c37f07965404, - 0xd7e77a8f87daf7fb, 0xdc33745ec97be906, - 0x86f0ac99b4e8dafd, 0x69a028bb3ded71a3, - 0xa8acd7c0222311bc, 0xc40832ea0d68ce0c, - 0xd2d80db02aabd62b, 0xf50a3fa490c30190, - 0x83c7088e1aab65db, 0x792667c6da79e0fa, - 0xa4b8cab1a1563f52, 0x577001b891185938, - 0xcde6fd5e09abcf26, 0xed4c0226b55e6f86, - 0x80b05e5ac60b6178, 0x544f8158315b05b4, - 0xa0dc75f1778e39d6, 0x696361ae3db1c721, - 0xc913936dd571c84c, 0x3bc3a19cd1e38e9, - 0xfb5878494ace3a5f, 0x4ab48a04065c723, - 0x9d174b2dcec0e47b, 0x62eb0d64283f9c76, - 0xc45d1df942711d9a, 0x3ba5d0bd324f8394, - 0xf5746577930d6500, 0xca8f44ec7ee36479, - 0x9968bf6abbe85f20, 0x7e998b13cf4e1ecb, - 0xbfc2ef456ae276e8, 0x9e3fedd8c321a67e, - 0xefb3ab16c59b14a2, 0xc5cfe94ef3ea101e, - 0x95d04aee3b80ece5, 0xbba1f1d158724a12, - 0xbb445da9ca61281f, 0x2a8a6e45ae8edc97, - 0xea1575143cf97226, 0xf52d09d71a3293bd, - 0x924d692ca61be758, 0x593c2626705f9c56, - 0xb6e0c377cfa2e12e, 0x6f8b2fb00c77836c, - 0xe498f455c38b997a, 0xb6dfb9c0f956447, - 0x8edf98b59a373fec, 0x4724bd4189bd5eac, - 0xb2977ee300c50fe7, 0x58edec91ec2cb657, - 0xdf3d5e9bc0f653e1, 0x2f2967b66737e3ed, - 0x8b865b215899f46c, 0xbd79e0d20082ee74, - 0xae67f1e9aec07187, 0xecd8590680a3aa11, - 0xda01ee641a708de9, 0xe80e6f4820cc9495, - 0x884134fe908658b2, 0x3109058d147fdcdd, - 0xaa51823e34a7eede, 0xbd4b46f0599fd415, - 0xd4e5e2cdc1d1ea96, 0x6c9e18ac7007c91a, - 0x850fadc09923329e, 0x3e2cf6bc604ddb0, - 0xa6539930bf6bff45, 0x84db8346b786151c, - 0xcfe87f7cef46ff16, 0xe612641865679a63, - 0x81f14fae158c5f6e, 0x4fcb7e8f3f60c07e, - 0xa26da3999aef7749, 0xe3be5e330f38f09d, - 0xcb090c8001ab551c, 0x5cadf5bfd3072cc5, - 0xfdcb4fa002162a63, 0x73d9732fc7c8f7f6, - 0x9e9f11c4014dda7e, 0x2867e7fddcdd9afa, - 0xc646d63501a1511d, 0xb281e1fd541501b8, - 0xf7d88bc24209a565, 0x1f225a7ca91a4226, - 0x9ae757596946075f, 
0x3375788de9b06958, - 0xc1a12d2fc3978937, 0x52d6b1641c83ae, - 0xf209787bb47d6b84, 0xc0678c5dbd23a49a, - 0x9745eb4d50ce6332, 0xf840b7ba963646e0, - 0xbd176620a501fbff, 0xb650e5a93bc3d898, - 0xec5d3fa8ce427aff, 0xa3e51f138ab4cebe, - 0x93ba47c980e98cdf, 0xc66f336c36b10137, - 0xb8a8d9bbe123f017, 0xb80b0047445d4184, - 0xe6d3102ad96cec1d, 0xa60dc059157491e5, - 0x9043ea1ac7e41392, 0x87c89837ad68db2f, - 0xb454e4a179dd1877, 0x29babe4598c311fb, - 0xe16a1dc9d8545e94, 0xf4296dd6fef3d67a, - 0x8ce2529e2734bb1d, 0x1899e4a65f58660c, - 0xb01ae745b101e9e4, 0x5ec05dcff72e7f8f, - 0xdc21a1171d42645d, 0x76707543f4fa1f73, - 0x899504ae72497eba, 0x6a06494a791c53a8, - 0xabfa45da0edbde69, 0x487db9d17636892, - 0xd6f8d7509292d603, 0x45a9d2845d3c42b6, - 0x865b86925b9bc5c2, 0xb8a2392ba45a9b2, - 0xa7f26836f282b732, 0x8e6cac7768d7141e, - 0xd1ef0244af2364ff, 0x3207d795430cd926, - 0x8335616aed761f1f, 0x7f44e6bd49e807b8, - 0xa402b9c5a8d3a6e7, 0x5f16206c9c6209a6, - 0xcd036837130890a1, 0x36dba887c37a8c0f, - 0x802221226be55a64, 0xc2494954da2c9789, - 0xa02aa96b06deb0fd, 0xf2db9baa10b7bd6c, - 0xc83553c5c8965d3d, 0x6f92829494e5acc7, - 0xfa42a8b73abbf48c, 0xcb772339ba1f17f9, - 0x9c69a97284b578d7, 0xff2a760414536efb, - 0xc38413cf25e2d70d, 0xfef5138519684aba, - 0xf46518c2ef5b8cd1, 0x7eb258665fc25d69, - 0x98bf2f79d5993802, 0xef2f773ffbd97a61, - 0xbeeefb584aff8603, 0xaafb550ffacfd8fa, - 0xeeaaba2e5dbf6784, 0x95ba2a53f983cf38, - 0x952ab45cfa97a0b2, 0xdd945a747bf26183, - 0xba756174393d88df, 0x94f971119aeef9e4, - 0xe912b9d1478ceb17, 0x7a37cd5601aab85d, - 0x91abb422ccb812ee, 0xac62e055c10ab33a, - 0xb616a12b7fe617aa, 0x577b986b314d6009, - 0xe39c49765fdf9d94, 0xed5a7e85fda0b80b, - 0x8e41ade9fbebc27d, 0x14588f13be847307, - 0xb1d219647ae6b31c, 0x596eb2d8ae258fc8, - 0xde469fbd99a05fe3, 0x6fca5f8ed9aef3bb, - 0x8aec23d680043bee, 0x25de7bb9480d5854, - 0xada72ccc20054ae9, 0xaf561aa79a10ae6a, - 0xd910f7ff28069da4, 0x1b2ba1518094da04, - 0x87aa9aff79042286, 0x90fb44d2f05d0842, - 0xa99541bf57452b28, 0x353a1607ac744a53, - 
0xd3fa922f2d1675f2, 0x42889b8997915ce8, - 0x847c9b5d7c2e09b7, 0x69956135febada11, - 0xa59bc234db398c25, 0x43fab9837e699095, - 0xcf02b2c21207ef2e, 0x94f967e45e03f4bb, - 0x8161afb94b44f57d, 0x1d1be0eebac278f5, - 0xa1ba1ba79e1632dc, 0x6462d92a69731732, - 0xca28a291859bbf93, 0x7d7b8f7503cfdcfe, - 0xfcb2cb35e702af78, 0x5cda735244c3d43e, - 0x9defbf01b061adab, 0x3a0888136afa64a7, - 0xc56baec21c7a1916, 0x88aaa1845b8fdd0, - 0xf6c69a72a3989f5b, 0x8aad549e57273d45, - 0x9a3c2087a63f6399, 0x36ac54e2f678864b, - 0xc0cb28a98fcf3c7f, 0x84576a1bb416a7dd, - 0xf0fdf2d3f3c30b9f, 0x656d44a2a11c51d5, - 0x969eb7c47859e743, 0x9f644ae5a4b1b325, - 0xbc4665b596706114, 0x873d5d9f0dde1fee, - 0xeb57ff22fc0c7959, 0xa90cb506d155a7ea, - 0x9316ff75dd87cbd8, 0x9a7f12442d588f2, - 0xb7dcbf5354e9bece, 0xc11ed6d538aeb2f, - 0xe5d3ef282a242e81, 0x8f1668c8a86da5fa, - 0x8fa475791a569d10, 0xf96e017d694487bc, - 0xb38d92d760ec4455, 0x37c981dcc395a9ac, - 0xe070f78d3927556a, 0x85bbe253f47b1417, - 0x8c469ab843b89562, 0x93956d7478ccec8e, - 0xaf58416654a6babb, 0x387ac8d1970027b2, - 0xdb2e51bfe9d0696a, 0x6997b05fcc0319e, - 0x88fcf317f22241e2, 0x441fece3bdf81f03, - 0xab3c2fddeeaad25a, 0xd527e81cad7626c3, - 0xd60b3bd56a5586f1, 0x8a71e223d8d3b074, - 0x85c7056562757456, 0xf6872d5667844e49, - 0xa738c6bebb12d16c, 0xb428f8ac016561db, - 0xd106f86e69d785c7, 0xe13336d701beba52, - 0x82a45b450226b39c, 0xecc0024661173473, - 0xa34d721642b06084, 0x27f002d7f95d0190, - 0xcc20ce9bd35c78a5, 0x31ec038df7b441f4, - 0xff290242c83396ce, 0x7e67047175a15271, - 0x9f79a169bd203e41, 0xf0062c6e984d386, - 0xc75809c42c684dd1, 0x52c07b78a3e60868, - 0xf92e0c3537826145, 0xa7709a56ccdf8a82, - 0x9bbcc7a142b17ccb, 0x88a66076400bb691, - 0xc2abf989935ddbfe, 0x6acff893d00ea435, - 0xf356f7ebf83552fe, 0x583f6b8c4124d43, - 0x98165af37b2153de, 0xc3727a337a8b704a, - 0xbe1bf1b059e9a8d6, 0x744f18c0592e4c5c, - 0xeda2ee1c7064130c, 0x1162def06f79df73, - 0x9485d4d1c63e8be7, 0x8addcb5645ac2ba8, - 0xb9a74a0637ce2ee1, 0x6d953e2bd7173692, - 0xe8111c87c5c1ba99, 
0xc8fa8db6ccdd0437, - 0x910ab1d4db9914a0, 0x1d9c9892400a22a2, - 0xb54d5e4a127f59c8, 0x2503beb6d00cab4b, - 0xe2a0b5dc971f303a, 0x2e44ae64840fd61d, - 0x8da471a9de737e24, 0x5ceaecfed289e5d2, - 0xb10d8e1456105dad, 0x7425a83e872c5f47, - 0xdd50f1996b947518, 0xd12f124e28f77719, - 0x8a5296ffe33cc92f, 0x82bd6b70d99aaa6f, - 0xace73cbfdc0bfb7b, 0x636cc64d1001550b, - 0xd8210befd30efa5a, 0x3c47f7e05401aa4e, - 0x8714a775e3e95c78, 0x65acfaec34810a71, - 0xa8d9d1535ce3b396, 0x7f1839a741a14d0d, - 0xd31045a8341ca07c, 0x1ede48111209a050, - 0x83ea2b892091e44d, 0x934aed0aab460432, - 0xa4e4b66b68b65d60, 0xf81da84d5617853f, - 0xce1de40642e3f4b9, 0x36251260ab9d668e, - 0x80d2ae83e9ce78f3, 0xc1d72b7c6b426019, - 0xa1075a24e4421730, 0xb24cf65b8612f81f, - 0xc94930ae1d529cfc, 0xdee033f26797b627, - 0xfb9b7cd9a4a7443c, 0x169840ef017da3b1, - 0x9d412e0806e88aa5, 0x8e1f289560ee864e, - 0xc491798a08a2ad4e, 0xf1a6f2bab92a27e2, - 0xf5b5d7ec8acb58a2, 0xae10af696774b1db, - 0x9991a6f3d6bf1765, 0xacca6da1e0a8ef29, - 0xbff610b0cc6edd3f, 0x17fd090a58d32af3, - 0xeff394dcff8a948e, 0xddfc4b4cef07f5b0, - 0x95f83d0a1fb69cd9, 0x4abdaf101564f98e, - 0xbb764c4ca7a4440f, 0x9d6d1ad41abe37f1, - 0xea53df5fd18d5513, 0x84c86189216dc5ed, - 0x92746b9be2f8552c, 0x32fd3cf5b4e49bb4, - 0xb7118682dbb66a77, 0x3fbc8c33221dc2a1, - 0xe4d5e82392a40515, 0xfabaf3feaa5334a, - 0x8f05b1163ba6832d, 0x29cb4d87f2a7400e, - 0xb2c71d5bca9023f8, 0x743e20e9ef511012, - 0xdf78e4b2bd342cf6, 0x914da9246b255416, - 0x8bab8eefb6409c1a, 0x1ad089b6c2f7548e, - 0xae9672aba3d0c320, 0xa184ac2473b529b1, - 0xda3c0f568cc4f3e8, 0xc9e5d72d90a2741e, - 0x8865899617fb1871, 0x7e2fa67c7a658892, - 0xaa7eebfb9df9de8d, 0xddbb901b98feeab7, - 0xd51ea6fa85785631, 0x552a74227f3ea565, - 0x8533285c936b35de, 0xd53a88958f87275f, - 0xa67ff273b8460356, 0x8a892abaf368f137, - 0xd01fef10a657842c, 0x2d2b7569b0432d85, - 0x8213f56a67f6b29b, 0x9c3b29620e29fc73, - 0xa298f2c501f45f42, 0x8349f3ba91b47b8f, - 0xcb3f2f7642717713, 0x241c70a936219a73, - 0xfe0efb53d30dd4d7, 0xed238cd383aa0110, - 
0x9ec95d1463e8a506, 0xf4363804324a40aa, - 0xc67bb4597ce2ce48, 0xb143c6053edcd0d5, - 0xf81aa16fdc1b81da, 0xdd94b7868e94050a, - 0x9b10a4e5e9913128, 0xca7cf2b4191c8326, - 0xc1d4ce1f63f57d72, 0xfd1c2f611f63a3f0, - 0xf24a01a73cf2dccf, 0xbc633b39673c8cec, - 0x976e41088617ca01, 0xd5be0503e085d813, - 0xbd49d14aa79dbc82, 0x4b2d8644d8a74e18, - 0xec9c459d51852ba2, 0xddf8e7d60ed1219e, - 0x93e1ab8252f33b45, 0xcabb90e5c942b503, - 0xb8da1662e7b00a17, 0x3d6a751f3b936243, - 0xe7109bfba19c0c9d, 0xcc512670a783ad4, - 0x906a617d450187e2, 0x27fb2b80668b24c5, - 0xb484f9dc9641e9da, 0xb1f9f660802dedf6, - 0xe1a63853bbd26451, 0x5e7873f8a0396973, - 0x8d07e33455637eb2, 0xdb0b487b6423e1e8, - 0xb049dc016abc5e5f, 0x91ce1a9a3d2cda62, - 0xdc5c5301c56b75f7, 0x7641a140cc7810fb, - 0x89b9b3e11b6329ba, 0xa9e904c87fcb0a9d, - 0xac2820d9623bf429, 0x546345fa9fbdcd44, - 0xd732290fbacaf133, 0xa97c177947ad4095, - 0x867f59a9d4bed6c0, 0x49ed8eabcccc485d, - 0xa81f301449ee8c70, 0x5c68f256bfff5a74, - 0xd226fc195c6a2f8c, 0x73832eec6fff3111, - 0x83585d8fd9c25db7, 0xc831fd53c5ff7eab, - 0xa42e74f3d032f525, 0xba3e7ca8b77f5e55, - 0xcd3a1230c43fb26f, 0x28ce1bd2e55f35eb, - 0x80444b5e7aa7cf85, 0x7980d163cf5b81b3, - 0xa0555e361951c366, 0xd7e105bcc332621f, - 0xc86ab5c39fa63440, 0x8dd9472bf3fefaa7, - 0xfa856334878fc150, 0xb14f98f6f0feb951, - 0x9c935e00d4b9d8d2, 0x6ed1bf9a569f33d3, - 0xc3b8358109e84f07, 0xa862f80ec4700c8, - 0xf4a642e14c6262c8, 0xcd27bb612758c0fa, - 0x98e7e9cccfbd7dbd, 0x8038d51cb897789c, - 0xbf21e44003acdd2c, 0xe0470a63e6bd56c3, - 0xeeea5d5004981478, 0x1858ccfce06cac74, - 0x95527a5202df0ccb, 0xf37801e0c43ebc8, - 0xbaa718e68396cffd, 0xd30560258f54e6ba, - 0xe950df20247c83fd, 0x47c6b82ef32a2069, - 0x91d28b7416cdd27e, 0x4cdc331d57fa5441, - 0xb6472e511c81471d, 0xe0133fe4adf8e952, - 0xe3d8f9e563a198e5, 0x58180fddd97723a6, - 0x8e679c2f5e44ff8f, 0x570f09eaa7ea7648, -}; - -} // namespace internal -} // namespace simdjson -/* end file src/internal/numberparsing_tables.cpp */ -/* begin file 
src/internal/simdprune_tables.cpp */ -#if SIMDJSON_IMPLEMENTATION_ARM64 || SIMDJSON_IMPLEMENTATION_HASWELL || \ - SIMDJSON_IMPLEMENTATION_WESTMERE || SIMDJSON_IMPLEMENTATION_PPC64 - -#include - -namespace simdjson { // table modified and copied from -namespace internal { // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetTable -SIMDJSON_DLLIMPORTEXPORT const unsigned char BitsSetTable256mul2[256] = { - 0, 2, 2, 4, 2, 4, 4, 6, 2, 4, 4, 6, 4, 6, 6, 8, 2, 4, 4, - 6, 4, 6, 6, 8, 4, 6, 6, 8, 6, 8, 8, 10, 2, 4, 4, 6, 4, 6, - 6, 8, 4, 6, 6, 8, 6, 8, 8, 10, 4, 6, 6, 8, 6, 8, 8, 10, 6, - 8, 8, 10, 8, 10, 10, 12, 2, 4, 4, 6, 4, 6, 6, 8, 4, 6, 6, 8, - 6, 8, 8, 10, 4, 6, 6, 8, 6, 8, 8, 10, 6, 8, 8, 10, 8, 10, 10, - 12, 4, 6, 6, 8, 6, 8, 8, 10, 6, 8, 8, 10, 8, 10, 10, 12, 6, 8, - 8, 10, 8, 10, 10, 12, 8, 10, 10, 12, 10, 12, 12, 14, 2, 4, 4, 6, 4, - 6, 6, 8, 4, 6, 6, 8, 6, 8, 8, 10, 4, 6, 6, 8, 6, 8, 8, 10, - 6, 8, 8, 10, 8, 10, 10, 12, 4, 6, 6, 8, 6, 8, 8, 10, 6, 8, 8, - 10, 8, 10, 10, 12, 6, 8, 8, 10, 8, 10, 10, 12, 8, 10, 10, 12, 10, 12, - 12, 14, 4, 6, 6, 8, 6, 8, 8, 10, 6, 8, 8, 10, 8, 10, 10, 12, 6, - 8, 8, 10, 8, 10, 10, 12, 8, 10, 10, 12, 10, 12, 12, 14, 6, 8, 8, 10, - 8, 10, 10, 12, 8, 10, 10, 12, 10, 12, 12, 14, 8, 10, 10, 12, 10, 12, 12, - 14, 10, 12, 12, 14, 12, 14, 14, 16}; - -SIMDJSON_DLLIMPORTEXPORT const uint8_t pshufb_combine_table[272] = { - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0x00, 0x01, 0x02, 0x03, - 0x04, 0x05, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, - 0x00, 0x01, 0x02, 0x03, 0x04, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, 0xff, 0xff, 0xff, 0x00, 0x01, 0x02, 0x03, 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x02, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, - 0x00, 0x01, 0x08, 0x09, 
0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0x00, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, - 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, -}; - -// 256 * 8 bytes = 2kB, easily fits in cache. -SIMDJSON_DLLIMPORTEXPORT const uint64_t thintable_epi8[256] = { - 0x0706050403020100, 0x0007060504030201, 0x0007060504030200, - 0x0000070605040302, 0x0007060504030100, 0x0000070605040301, - 0x0000070605040300, 0x0000000706050403, 0x0007060504020100, - 0x0000070605040201, 0x0000070605040200, 0x0000000706050402, - 0x0000070605040100, 0x0000000706050401, 0x0000000706050400, - 0x0000000007060504, 0x0007060503020100, 0x0000070605030201, - 0x0000070605030200, 0x0000000706050302, 0x0000070605030100, - 0x0000000706050301, 0x0000000706050300, 0x0000000007060503, - 0x0000070605020100, 0x0000000706050201, 0x0000000706050200, - 0x0000000007060502, 0x0000000706050100, 0x0000000007060501, - 0x0000000007060500, 0x0000000000070605, 0x0007060403020100, - 0x0000070604030201, 0x0000070604030200, 0x0000000706040302, - 0x0000070604030100, 0x0000000706040301, 0x0000000706040300, - 0x0000000007060403, 0x0000070604020100, 0x0000000706040201, - 0x0000000706040200, 0x0000000007060402, 0x0000000706040100, - 0x0000000007060401, 0x0000000007060400, 0x0000000000070604, - 0x0000070603020100, 0x0000000706030201, 0x0000000706030200, - 0x0000000007060302, 0x0000000706030100, 0x0000000007060301, - 0x0000000007060300, 0x0000000000070603, 0x0000000706020100, - 0x0000000007060201, 0x0000000007060200, 0x0000000000070602, - 0x0000000007060100, 0x0000000000070601, 0x0000000000070600, - 0x0000000000000706, 0x0007050403020100, 0x0000070504030201, - 0x0000070504030200, 0x0000000705040302, 0x0000070504030100, - 0x0000000705040301, 0x0000000705040300, 0x0000000007050403, - 0x0000070504020100, 0x0000000705040201, 0x0000000705040200, - 0x0000000007050402, 0x0000000705040100, 0x0000000007050401, 
- 0x0000000007050400, 0x0000000000070504, 0x0000070503020100, - 0x0000000705030201, 0x0000000705030200, 0x0000000007050302, - 0x0000000705030100, 0x0000000007050301, 0x0000000007050300, - 0x0000000000070503, 0x0000000705020100, 0x0000000007050201, - 0x0000000007050200, 0x0000000000070502, 0x0000000007050100, - 0x0000000000070501, 0x0000000000070500, 0x0000000000000705, - 0x0000070403020100, 0x0000000704030201, 0x0000000704030200, - 0x0000000007040302, 0x0000000704030100, 0x0000000007040301, - 0x0000000007040300, 0x0000000000070403, 0x0000000704020100, - 0x0000000007040201, 0x0000000007040200, 0x0000000000070402, - 0x0000000007040100, 0x0000000000070401, 0x0000000000070400, - 0x0000000000000704, 0x0000000703020100, 0x0000000007030201, - 0x0000000007030200, 0x0000000000070302, 0x0000000007030100, - 0x0000000000070301, 0x0000000000070300, 0x0000000000000703, - 0x0000000007020100, 0x0000000000070201, 0x0000000000070200, - 0x0000000000000702, 0x0000000000070100, 0x0000000000000701, - 0x0000000000000700, 0x0000000000000007, 0x0006050403020100, - 0x0000060504030201, 0x0000060504030200, 0x0000000605040302, - 0x0000060504030100, 0x0000000605040301, 0x0000000605040300, - 0x0000000006050403, 0x0000060504020100, 0x0000000605040201, - 0x0000000605040200, 0x0000000006050402, 0x0000000605040100, - 0x0000000006050401, 0x0000000006050400, 0x0000000000060504, - 0x0000060503020100, 0x0000000605030201, 0x0000000605030200, - 0x0000000006050302, 0x0000000605030100, 0x0000000006050301, - 0x0000000006050300, 0x0000000000060503, 0x0000000605020100, - 0x0000000006050201, 0x0000000006050200, 0x0000000000060502, - 0x0000000006050100, 0x0000000000060501, 0x0000000000060500, - 0x0000000000000605, 0x0000060403020100, 0x0000000604030201, - 0x0000000604030200, 0x0000000006040302, 0x0000000604030100, - 0x0000000006040301, 0x0000000006040300, 0x0000000000060403, - 0x0000000604020100, 0x0000000006040201, 0x0000000006040200, - 0x0000000000060402, 0x0000000006040100, 0x0000000000060401, - 
0x0000000000060400, 0x0000000000000604, 0x0000000603020100, - 0x0000000006030201, 0x0000000006030200, 0x0000000000060302, - 0x0000000006030100, 0x0000000000060301, 0x0000000000060300, - 0x0000000000000603, 0x0000000006020100, 0x0000000000060201, - 0x0000000000060200, 0x0000000000000602, 0x0000000000060100, - 0x0000000000000601, 0x0000000000000600, 0x0000000000000006, - 0x0000050403020100, 0x0000000504030201, 0x0000000504030200, - 0x0000000005040302, 0x0000000504030100, 0x0000000005040301, - 0x0000000005040300, 0x0000000000050403, 0x0000000504020100, - 0x0000000005040201, 0x0000000005040200, 0x0000000000050402, - 0x0000000005040100, 0x0000000000050401, 0x0000000000050400, - 0x0000000000000504, 0x0000000503020100, 0x0000000005030201, - 0x0000000005030200, 0x0000000000050302, 0x0000000005030100, - 0x0000000000050301, 0x0000000000050300, 0x0000000000000503, - 0x0000000005020100, 0x0000000000050201, 0x0000000000050200, - 0x0000000000000502, 0x0000000000050100, 0x0000000000000501, - 0x0000000000000500, 0x0000000000000005, 0x0000000403020100, - 0x0000000004030201, 0x0000000004030200, 0x0000000000040302, - 0x0000000004030100, 0x0000000000040301, 0x0000000000040300, - 0x0000000000000403, 0x0000000004020100, 0x0000000000040201, - 0x0000000000040200, 0x0000000000000402, 0x0000000000040100, - 0x0000000000000401, 0x0000000000000400, 0x0000000000000004, - 0x0000000003020100, 0x0000000000030201, 0x0000000000030200, - 0x0000000000000302, 0x0000000000030100, 0x0000000000000301, - 0x0000000000000300, 0x0000000000000003, 0x0000000000020100, - 0x0000000000000201, 0x0000000000000200, 0x0000000000000002, - 0x0000000000000100, 0x0000000000000001, 0x0000000000000000, - 0x0000000000000000, -}; // static uint64_t thintable_epi8[256] - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_IMPLEMENTATION_ARM64 || SIMDJSON_IMPLEMENTATION_HASWELL || - // SIMDJSON_IMPLEMENTATION_WESTMERE || SIMDJSON_IMPLEMENTATION_PPC64 -/* end file src/internal/simdprune_tables.cpp */ -/* 
begin file src/implementation.cpp */ -#include - -namespace simdjson { - -bool implementation::supported_by_runtime_system() const { - uint32_t required_instruction_sets = this->required_instruction_sets(); - uint32_t supported_instruction_sets = - internal::detect_supported_architectures(); - return ((supported_instruction_sets & required_instruction_sets) == - required_instruction_sets); -} - -namespace internal { - -// Static array of known implementations. We're hoping these get baked into the -// executable -// without requiring a static initializer. - -#if SIMDJSON_IMPLEMENTATION_HASWELL -static const haswell::implementation *get_haswell_singleton() { - static const haswell::implementation haswell_singleton{}; - return &haswell_singleton; -} -#endif -#if SIMDJSON_IMPLEMENTATION_WESTMERE -static const westmere::implementation *get_westmere_singleton() { - static const westmere::implementation westmere_singleton{}; - return &westmere_singleton; -} -#endif // SIMDJSON_IMPLEMENTATION_WESTMERE -#if SIMDJSON_IMPLEMENTATION_ARM64 -static const arm64::implementation *get_arm64_singleton() { - static const arm64::implementation arm64_singleton{}; - return &arm64_singleton; -} -#endif // SIMDJSON_IMPLEMENTATION_ARM64 -#if SIMDJSON_IMPLEMENTATION_PPC64 -static const ppc64::implementation *get_ppc64_singleton() { - static const ppc64::implementation ppc64_singleton{}; - return &ppc64_singleton; -} -#endif // SIMDJSON_IMPLEMENTATION_PPC64 -#if SIMDJSON_IMPLEMENTATION_FALLBACK -static const fallback::implementation *get_fallback_singleton() { - static const fallback::implementation fallback_singleton{}; - return &fallback_singleton; -} -#endif // SIMDJSON_IMPLEMENTATION_FALLBACK - -/** - * @private Detects best supported implementation on first use, and sets it - */ -class detect_best_supported_implementation_on_first_use final - : public implementation { - public: - const std::string &name() const noexcept final { - return set_best()->name(); - } - const std::string 
&description() const noexcept final { - return set_best()->description(); - } - uint32_t required_instruction_sets() const noexcept final { - return set_best()->required_instruction_sets(); - } - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final { - return set_best()->create_dom_parser_implementation( - capacity, max_length, dst); - } - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final { - return set_best()->minify(buf, len, dst, dst_len); - } - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final override { - return set_best()->validate_utf8(buf, len); - } - simdjson_really_inline - detect_best_supported_implementation_on_first_use() noexcept - : implementation( - "best_supported_detector", - "Detects the best supported implementation and sets it", - 0) {} - - private: - const implementation *set_best() const noexcept; -}; - -static const std::initializer_list - &get_available_implementation_pointers() { - static const std::initializer_list - available_implementation_pointers { -#if SIMDJSON_IMPLEMENTATION_HASWELL - get_haswell_singleton(), -#endif -#if SIMDJSON_IMPLEMENTATION_WESTMERE - get_westmere_singleton(), -#endif -#if SIMDJSON_IMPLEMENTATION_ARM64 - get_arm64_singleton(), -#endif -#if SIMDJSON_IMPLEMENTATION_PPC64 - get_ppc64_singleton(), -#endif -#if SIMDJSON_IMPLEMENTATION_FALLBACK - get_fallback_singleton(), -#endif - }; // available_implementation_pointers - return available_implementation_pointers; -} - -// So we can return UNSUPPORTED_ARCHITECTURE from the parser when there is no -// support -class unsupported_implementation final : public implementation { - public: - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t, - size_t, - std::unique_ptr &) const - noexcept final { - return UNSUPPORTED_ARCHITECTURE; - } 
- simdjson_warn_unused error_code - minify(const uint8_t *, size_t, uint8_t *, size_t &) const - noexcept final override { - return UNSUPPORTED_ARCHITECTURE; - } - simdjson_warn_unused bool validate_utf8(const char *, size_t) const - noexcept final override { - return false; // Just refuse to validate. Given that we have a fallback - // implementation - // it seems unlikely that unsupported_implementation will ever be used. - // If it is used, - // then it will flag all strings as invalid. The alternative is to - // return an error_code - // from which the user has to figure out whether the string is valid - // UTF-8... which seems - // like a lot of work just to handle the very unlikely case that we have - // an unsupported - // implementation. And, when it does happen (that we have an unsupported - // implementation), - // what are the chances that the programmer has a fallback? Given that - // *we* provide the - // fallback, it implies that the programmer would need a fallback for - // our fallback. 
- } - unsupported_implementation() - : implementation("unsupported", - "Unsupported CPU (no detected SIMD instructions)", - 0) {} -}; - -const unsupported_implementation *get_unsupported_singleton() { - static const unsupported_implementation unsupported_singleton{}; - return &unsupported_singleton; -} - -size_t available_implementation_list::size() const noexcept { - return internal::get_available_implementation_pointers().size(); -} -const implementation *const *available_implementation_list::begin() const - noexcept { - return internal::get_available_implementation_pointers().begin(); -} -const implementation *const *available_implementation_list::end() const - noexcept { - return internal::get_available_implementation_pointers().end(); -} -const implementation *available_implementation_list::detect_best_supported() - const noexcept { - // They are prelisted in priority order, so we just go down the list - uint32_t supported_instruction_sets = - internal::detect_supported_architectures(); - for (const implementation *impl : - internal::get_available_implementation_pointers()) { - uint32_t required_instruction_sets = impl->required_instruction_sets(); - if ((supported_instruction_sets & required_instruction_sets) == - required_instruction_sets) { - return impl; - } - } - return get_unsupported_singleton(); // this should never happen? 
-} - -const implementation * -detect_best_supported_implementation_on_first_use::set_best() const noexcept { - SIMDJSON_PUSH_DISABLE_WARNINGS - SIMDJSON_DISABLE_DEPRECATED_WARNING // Disable CRT_SECURE warning on MSVC: - // manually verified this is safe - char *force_implementation_name = - getenv("SIMDJSON_FORCE_IMPLEMENTATION"); - SIMDJSON_POP_DISABLE_WARNINGS - - if (force_implementation_name) { - auto force_implementation = - get_available_implementations()[force_implementation_name]; - if (force_implementation) { - return get_active_implementation() = force_implementation; - } else { - // Note: abort() and stderr usage within the library is forbidden. - return get_active_implementation() = get_unsupported_singleton(); - } - } - return get_active_implementation() = - get_available_implementations().detect_best_supported(); -} - -} // namespace internal - -SIMDJSON_DLLIMPORTEXPORT const internal::available_implementation_list & -get_available_implementations() { - static const internal::available_implementation_list - available_implementations{}; - return available_implementations; -} - -SIMDJSON_DLLIMPORTEXPORT internal::atomic_ptr - &get_active_implementation() { - static const internal::detect_best_supported_implementation_on_first_use - detect_best_supported_implementation_on_first_use_singleton; - static internal::atomic_ptr active_implementation{ - &detect_best_supported_implementation_on_first_use_singleton}; - return active_implementation; -} - -simdjson_warn_unused error_code minify(const char *buf, - size_t len, - char *dst, - size_t &dst_len) noexcept { - return get_active_implementation()->minify( - reinterpret_cast(buf), - len, - reinterpret_cast(dst), - dst_len); -} -simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) noexcept { - return get_active_implementation()->validate_utf8(buf, len); -} - -const implementation *builtin_implementation() { - static const implementation *builtin_impl = - 
get_available_implementations()[SIMDJSON_STRINGIFY( - SIMDJSON_BUILTIN_IMPLEMENTATION)]; - assert(builtin_impl); - return builtin_impl; -} - - -} // namespace simdjson -/* end file src/implementation.cpp */ - -#if SIMDJSON_IMPLEMENTATION_ARM64 -/* begin file src/arm64/implementation.cpp */ -/* begin file include/simdjson/arm64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "arm64" -// #define SIMDJSON_IMPLEMENTATION arm64 -/* end file include/simdjson/arm64/begin.h */ - -namespace simdjson { -namespace arm64 { - -simdjson_warn_unused error_code -implementation::create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const noexcept { - dst.reset(new (std::nothrow) dom_parser_implementation()); - if (!dst) { - return MEMALLOC; - } - if (auto err = dst->set_capacity(capacity)) return err; - if (auto err = dst->set_max_depth(max_depth)) return err; - return SUCCESS; -} - -} // namespace arm64 -} // namespace simdjson - -/* begin file include/simdjson/arm64/end.h */ -/* end file include/simdjson/arm64/end.h */ -/* end file src/arm64/implementation.cpp */ -/* begin file src/arm64/dom_parser_implementation.cpp */ -/* begin file include/simdjson/arm64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "arm64" -// #define SIMDJSON_IMPLEMENTATION arm64 -/* end file include/simdjson/arm64/begin.h */ - -// -// Stage 1 -// -namespace simdjson { -namespace arm64 { -namespace { - -using namespace simd; - -struct json_character_block { - static simdjson_really_inline json_character_block - classify(const simd::simd8x64 &in); - - simdjson_really_inline uint64_t whitespace() const noexcept { - return _whitespace; - } - simdjson_really_inline uint64_t op() const noexcept { return _op; } - simdjson_really_inline uint64_t scalar() const noexcept { - return ~(op() | whitespace()); - } - - uint64_t _whitespace; - uint64_t _op; -}; - -simdjson_really_inline json_character_block -json_character_block::classify(const simd::simd8x64 &in) { - 
// Functional programming causes trouble with Visual Studio. - // Keeping this version in comments since it is much nicer: - // auto v = in.map([&](simd8 chunk) { - // auto nib_lo = chunk & 0xf; - // auto nib_hi = chunk.shr<4>(); - // auto shuf_lo = nib_lo.lookup_16(16, 0, 0, 0, 0, 0, 0, 0, 0, 8, - // 12, 1, 2, 9, 0, 0); - // auto shuf_hi = nib_hi.lookup_16(8, 0, 18, 4, 0, 1, 0, 1, 0, 0, - // 0, 3, 2, 1, 0, 0); - // return shuf_lo & shuf_hi; - // }); - const simd8 table1( - 16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2, 9, 0, 0); - const simd8 table2( - 8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1, 0, 0); - - simd8x64 v((in.chunks[0] & 0xf).lookup_16(table1) & - (in.chunks[0].shr<4>()).lookup_16(table2), - (in.chunks[1] & 0xf).lookup_16(table1) & - (in.chunks[1].shr<4>()).lookup_16(table2), - (in.chunks[2] & 0xf).lookup_16(table1) & - (in.chunks[2].shr<4>()).lookup_16(table2), - (in.chunks[3] & 0xf).lookup_16(table1) & - (in.chunks[3].shr<4>()).lookup_16(table2)); - - - // We compute whitespace and op separately. If the code later only use one - // or the - // other, given the fact that all functions are aggressively inlined, we can - // hope that useless computations will be omitted. This is namely case when - // minifying (we only need whitespace). *However* if we only need spaces, - // it is likely that we will still compute 'v' above with two lookup_16: one - // could do it a bit cheaper. This is in contrast with the x64 - // implementations - // where we can, efficiently, do the white space and structural matching - // separately. One reason for this difference is that on ARM NEON, the table - // lookups either zero or leave unchanged the characters exceeding 0xF - // whereas - // on x64, the equivalent instruction (pshufb) automatically applies a mask, - // ignoring the 4 most significant bits. Thus the x64 implementation is - // optimized differently. 
This being said, if you use this code strictly - // just for minification (or just to identify the structural characters), - // there is a small untaken optimization opportunity here. We deliberately - // do not pick it up. - - uint64_t op = simd8x64(v.chunks[0].any_bits_set(0x7), - v.chunks[1].any_bits_set(0x7), - v.chunks[2].any_bits_set(0x7), - v.chunks[3].any_bits_set(0x7)) - .to_bitmask(); - - uint64_t whitespace = simd8x64(v.chunks[0].any_bits_set(0x18), - v.chunks[1].any_bits_set(0x18), - v.chunks[2].any_bits_set(0x18), - v.chunks[3].any_bits_set(0x18)) - .to_bitmask(); - - return {whitespace, op}; -} - -simdjson_really_inline bool is_ascii(const simd8x64 &input) { - simd8 bits = input.reduce_or(); - return bits.max_val() < 0b10000000u; -} - -simdjson_unused simdjson_really_inline simd8 must_be_continuation( - const simd8 prev1, - const simd8 prev2, - const simd8 prev3) { - simd8 is_second_byte = prev1 >= uint8_t(0b11000000u); - simd8 is_third_byte = prev2 >= uint8_t(0b11100000u); - simd8 is_fourth_byte = prev3 >= uint8_t(0b11110000u); - // Use ^ instead of | for is_*_byte, because ^ is commutative, and the - // caller is using ^ as well. - // This will work fine because we only have to report errors for cases with - // 0-1 lead bytes. - // Multiple lead bytes implies 2 overlapping multibyte characters, and if - // that happens, there is - // guaranteed to be at least *one* lead byte that is part of only 1 other - // multibyte character. - // The error will be detected there. 
- return is_second_byte ^ is_third_byte ^ is_fourth_byte; -} - -simdjson_really_inline simd8 must_be_2_3_continuation( - const simd8 prev2, const simd8 prev3) { - simd8 is_third_byte = prev2 >= uint8_t(0b11100000u); - simd8 is_fourth_byte = prev3 >= uint8_t(0b11110000u); - return is_third_byte ^ is_fourth_byte; -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -/* begin file src/generic/stage1/utf8_lookup4_algorithm.h */ -namespace simdjson { -namespace arm64 { -namespace { -namespace utf8_validation { - -using namespace simd; - -simdjson_really_inline simd8 check_special_cases( - const simd8 input, const simd8 prev1) { - // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII) - // Bit 1 = Too Long (ASCII followed by continuation) - // Bit 2 = Overlong 3-byte - // Bit 4 = Surrogate - // Bit 5 = Overlong 2-byte - // Bit 7 = Two Continuations - constexpr const uint8_t TOO_SHORT = 1 << 0; // 11______ 0_______ - // 11______ 11______ - constexpr const uint8_t TOO_LONG = 1 << 1; // 0_______ 10______ - constexpr const uint8_t OVERLONG_3 = 1 << 2; // 11100000 100_____ - constexpr const uint8_t SURROGATE = 1 << 4; // 11101101 101_____ - constexpr const uint8_t OVERLONG_2 = 1 << 5; // 1100000_ 10______ - constexpr const uint8_t TWO_CONTS = 1 << 7; // 10______ 10______ - constexpr const uint8_t TOO_LARGE = 1 << 3; // 11110100 1001____ - // 11110100 101_____ - // 11110101 1001____ - // 11110101 101_____ - // 1111011_ 1001____ - // 1111011_ 101_____ - // 11111___ 1001____ - // 11111___ 101_____ - constexpr const uint8_t TOO_LARGE_1000 = 1 << 6; - // 11110101 1000____ - // 1111011_ 1000____ - // 11111___ 1000____ - constexpr const uint8_t OVERLONG_4 = 1 << 6; // 11110000 1000____ - - const simd8 byte_1_high = prev1.shr<4>().lookup_16( - // 0_______ ________ - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - // 10______ ________ - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - // 1100____ 
________ - TOO_SHORT | OVERLONG_2, - // 1101____ ________ - TOO_SHORT, - // 1110____ ________ - TOO_SHORT | OVERLONG_3 | SURROGATE, - // 1111____ ________ - TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4); - constexpr const uint8_t CARRY = - TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 . - const simd8 byte_1_low = - (prev1 & 0x0F) - .lookup_16( - // ____0000 ________ - CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, - // ____0001 ________ - CARRY | OVERLONG_2, - // ____001_ ________ - CARRY, - CARRY, - - // ____0100 ________ - CARRY | TOO_LARGE, - // ____0101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____011_ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - - // ____1___ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____1101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000); - const simd8 byte_2_high = input.shr<4>().lookup_16( - // ________ 0_______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - - // ________ 1000____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | - OVERLONG_4, - // ________ 1001____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, - // ________ 101_____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - - // ________ 11______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT); - return (byte_1_high & byte_1_low & byte_2_high); -} -simdjson_really_inline simd8 check_multibyte_lengths( - const simd8 input, - const simd8 prev_input, - const simd8 sc) { - simd8 prev2 = input.prev<2>(prev_input); - simd8 prev3 = input.prev<3>(prev_input); - simd8 must23 = - 
simd8(must_be_2_3_continuation(prev2, prev3)); - simd8 must23_80 = must23 & uint8_t(0x80); - return must23_80 ^ sc; -} - -// -// Return nonzero if there are incomplete multibyte characters at the end of the -// block: -// e.g. if there is a 4-byte character, but it's 3 bytes from the end. -// -simdjson_really_inline simd8 is_incomplete( - const simd8 input) { - // If the previous input's last 3 bytes match this, they're too short (they - // ended at EOF): - // ... 1111____ 111_____ 11______ - static const uint8_t max_array[32] = {255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 0b11110000u - 1, - 0b11100000u - 1, - 0b11000000u - 1}; - const simd8 max_value( - &max_array[sizeof(max_array) - sizeof(simd8)]); - return input.gt_bits(max_value); -} - -struct utf8_checker { - // If this is nonzero, there has been a UTF-8 error. - simd8 error; - // The last input we received - simd8 prev_input_block; - // Whether the last input we received was incomplete (used for ASCII fast - // path) - simd8 prev_incomplete; - - // - // Check whether the current bytes are valid UTF-8. - // - simdjson_really_inline void check_utf8_bytes( - const simd8 input, const simd8 prev_input) { - // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or - // 4+ lead bytes - // (2, 3, 4-byte leads become large positive numbers instead of small - // negative numbers) - simd8 prev1 = input.prev<1>(prev_input); - simd8 sc = check_special_cases(input, prev1); - this->error |= check_multibyte_lengths(input, prev_input, sc); - } - - // The only problem that can happen at EOF is that a multibyte character is - // too short - // or a byte value too large in the last bytes: check_special_cases only - // checks for bytes - // too large in the first of two bytes. 
- simdjson_really_inline void check_eof() { - // If the previous block had incomplete UTF-8 characters at the end, an - // ASCII block can't - // possibly finish them. - this->error |= this->prev_incomplete; - } - - simdjson_really_inline void check_next_input( - const simd8x64 &input) { - if (simdjson_likely(is_ascii(input))) { - this->error |= this->prev_incomplete; - } else { - // you might think that a for-loop would work, but under Visual - // Studio, it is not good enough. - static_assert( - (simd8x64::NUM_CHUNKS == 2) || - (simd8x64::NUM_CHUNKS == 4), - "We support either two or four chunks per 64-byte block."); - if (simd8x64::NUM_CHUNKS == 2) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - } else if (simd8x64::NUM_CHUNKS == 4) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - this->check_utf8_bytes(input.chunks[2], input.chunks[1]); - this->check_utf8_bytes(input.chunks[3], input.chunks[2]); - } - this->prev_incomplete = - is_incomplete(input.chunks[simd8x64::NUM_CHUNKS - 1]); - this->prev_input_block = - input.chunks[simd8x64::NUM_CHUNKS - 1]; - } - } - // do not forget to call check_eof! - simdjson_really_inline error_code errors() { - return this->error.any_bits_set_anywhere() ? 
error_code::UTF8_ERROR - : error_code::SUCCESS; - } - -}; // struct utf8_checker -} // namespace utf8_validation - -using utf8_validation::utf8_checker; - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/utf8_lookup4_algorithm.h */ -/* begin file src/generic/stage1/json_structural_indexer.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -/* begin file src/generic/stage1/buf_block_reader.h */ -namespace simdjson { -namespace arm64 { -namespace { - -// Walks through a buffer in block-sized increments, loading the last part with -// spaces -template -struct buf_block_reader { - public: - simdjson_really_inline buf_block_reader(const uint8_t *_buf, size_t _len); - simdjson_really_inline size_t block_index(); - simdjson_really_inline bool has_full_block() const; - simdjson_really_inline const uint8_t *full_block() const; - /** - * Get the last block, padded with spaces. - * - * There will always be a last block, with at least 1 byte, unless len == 0 - * (in which case this - * function fills the buffer with spaces and returns 0. In particular, if - * len == STEP_SIZE there - * will be 0 full_blocks and 1 remainder block with STEP_SIZE bytes and no - * spaces for padding. - * - * @return the number of effective characters in the last block. 
- */ - simdjson_really_inline size_t get_remainder(uint8_t *dst) const; - simdjson_really_inline void advance(); - - private: - const uint8_t *buf; - const size_t len; - const size_t lenminusstep; - size_t idx; -}; - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text_64(const uint8_t *text) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < sizeof(simd8x64); i++) { - buf[i] = int8_t(text[i]) < ' ' ? '_' : int8_t(text[i]); - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text(const simd8x64 &in) { - static char buf[sizeof(simd8x64) + 1]; - in.store(reinterpret_cast(buf)); - for (size_t i = 0; i < sizeof(simd8x64); i++) { - if (buf[i] < ' ') { - buf[i] = '_'; - } - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -simdjson_unused static char *format_mask(uint64_t mask) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < 64; i++) { - buf[i] = (mask & (size_t(1) << i)) ? 'X' : ' '; - } - buf[64] = '\0'; - return buf; -} - -template -simdjson_really_inline buf_block_reader::buf_block_reader( - const uint8_t *_buf, size_t _len) - : buf{_buf}, - len{_len}, - lenminusstep{len < STEP_SIZE ? 
0 : len - STEP_SIZE}, - idx{0} {} - -template -simdjson_really_inline size_t buf_block_reader::block_index() { - return idx; -} - -template -simdjson_really_inline bool buf_block_reader::has_full_block() - const { - return idx < lenminusstep; -} - -template -simdjson_really_inline const uint8_t *buf_block_reader::full_block() - const { - return &buf[idx]; -} - -template -simdjson_really_inline size_t -buf_block_reader::get_remainder(uint8_t *dst) const { - if (len == idx) { - return 0; - } // memcpy(dst, null, 0) will trigger an error with some sanitizers - std::memset(dst, 0x20, STEP_SIZE); // std::memset STEP_SIZE because it's - // more efficient to write out 8 or 16 - // bytes at once. - std::memcpy(dst, buf + idx, len - idx); - return len - idx; -} - -template -simdjson_really_inline void buf_block_reader::advance() { - idx += STEP_SIZE; -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/buf_block_reader.h */ -/* begin file src/generic/stage1/json_string_scanner.h */ -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -struct json_string_block { - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_string_block(uint64_t backslash, - uint64_t escaped, - uint64_t quote, - uint64_t in_string) - : _backslash(backslash), - _escaped(escaped), - _quote(quote), - _in_string(in_string) {} - - // Escaped characters (characters following an escape() character) - simdjson_really_inline uint64_t escaped() const { return _escaped; } - // Escape characters (backslashes that are not escaped--i.e. 
in \\, includes - // only the first \) - simdjson_really_inline uint64_t escape() const { - return _backslash & ~_escaped; - } - // Real (non-backslashed) quotes - simdjson_really_inline uint64_t quote() const { return _quote; } - // Start quotes of strings - simdjson_really_inline uint64_t string_start() const { - return _quote & _in_string; - } - // End quotes of strings - simdjson_really_inline uint64_t string_end() const { - return _quote & ~_in_string; - } - // Only characters inside the string (not including the quotes) - simdjson_really_inline uint64_t string_content() const { - return _in_string & ~_quote; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_inside_string(uint64_t mask) const { - return mask & _in_string; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const { - return mask & ~_in_string; - } - // Tail of string (everything except the start quote) - simdjson_really_inline uint64_t string_tail() const { - return _in_string ^ _quote; - } - - // backslash characters - uint64_t _backslash; - // escaped characters (backslashed--does not include the hex characters - // after \u) - uint64_t _escaped; - // real quotes (non-backslashed ones) - uint64_t _quote; - // string characters (includes start quote but not end quote) - uint64_t _in_string; -}; - -// Scans blocks for string characters, storing the state necessary to do so -class json_string_scanner { - public: - simdjson_really_inline json_string_block - next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Intended to be defined by the implementation - simdjson_really_inline uint64_t find_escaped(uint64_t escape); - simdjson_really_inline uint64_t 
find_escaped_branchless(uint64_t escape); - - // Whether the last iteration was still inside a string (all 1's = true, all - // 0's = false). - uint64_t prev_in_string = 0ULL; - // Whether the first character of the next iteration is escaped. - uint64_t prev_escaped = 0ULL; -}; - -// -// Finds escaped characters (characters following \). -// -// Handles runs of backslashes like \\\" and \\\\" correctly (yielding 0101 and -// 01010, respectively). -// -// Does this by: -// - Shift the escape mask to get potentially escaped characters (characters -// after backslashes). -// - Mask escaped sequences that start on *even* bits with 1010101010 (odd bits -// are escaped, even bits are not) -// - Mask escaped sequences that start on *odd* bits with 0101010101 (even bits -// are escaped, odd bits are not) -// -// To distinguish between escaped sequences starting on even/odd bits, it finds -// the start of all -// escape sequences, filters out the ones that start on even bits, and adds that -// to the mask of -// escape sequences. This causes the addition to clear out the sequences -// starting on odd bits (since -// the start bit causes a carry), and leaves even-bit sequences alone. 
-// -// Example: -// -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// escape | xxx | xx xxx xxx xx xx | Removed overflow backslash; -// will | it into follows_escape -// odd_starts | x | x x x | escape & ~even_bits & -// ~follows_escape -// even_seq | c| cxxx c xx c | c = carry bit -- will be -// masked out later -// invert_mask | | cxxx c xx c| even_seq << 1 -// follows_escape | xx | x xx xxx xxx xx xx | Includes overflow bit -// escaped | x | x x x x x x x x | -// desired | x | x x x x x x x x | -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// -simdjson_really_inline uint64_t -json_string_scanner::find_escaped_branchless(uint64_t backslash) { - // If there was overflow, pretend the first character isn't a backslash - backslash &= ~prev_escaped; - uint64_t follows_escape = backslash << 1 | prev_escaped; - - // Get sequences starting on even bits by clearing out the odd series using - // + - const uint64_t even_bits = 0x5555555555555555ULL; - uint64_t odd_sequence_starts = backslash & ~even_bits & ~follows_escape; - uint64_t sequences_starting_on_even_bits; - prev_escaped = add_overflow( - odd_sequence_starts, backslash, &sequences_starting_on_even_bits); - uint64_t invert_mask = - sequences_starting_on_even_bits - << 1; // The mask we want to return is the *escaped* bits, not escapes. - - // Mask every other backslashed character as an escaped character - // Flip the mask for sequences that start on even bits, to correct them - return (even_bits ^ invert_mask) & follows_escape; -} - -// -// Return a mask of all string characters plus end quotes. -// -// prev_escaped is overflow saying whether the next character is escaped. -// prev_in_string is overflow saying whether we're still in a string. -// -// Backslash sequences outside of quotes will be detected in stage 2. 
-// -simdjson_really_inline json_string_block -json_string_scanner::next(const simd::simd8x64 &in) { - const uint64_t backslash = in.eq('\\'); - const uint64_t escaped = find_escaped(backslash); - const uint64_t quote = in.eq('"') & ~escaped; - - // - // prefix_xor flips on bits inside the string (and flips off the end quote). - // - // Then we xor with prev_in_string: if we were in a string already, its - // effect is flipped - // (characters inside strings are outside, and characters outside strings - // are inside). - // - const uint64_t in_string = prefix_xor(quote) ^ prev_in_string; - - // - // Check if we're still in a string at the end of the box so the next block - // will know - // - // right shift of a signed value expected to be well-defined and standard - // compliant as of C++20, John Regher from Utah U. says this is fine code - // - prev_in_string = uint64_t(static_cast(in_string) >> 63); - - // Use ^ to turn the beginning quote off, and the end quote on. - - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_string_block(backslash, escaped, quote, in_string); -} - -simdjson_really_inline error_code json_string_scanner::finish() { - if (prev_in_string) { - return UNCLOSED_STRING; - } - return SUCCESS; -} - -} // namespace stage1 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/json_string_scanner.h */ -/* begin file src/generic/stage1/json_scanner.h */ -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -/** - * A block of scanned json, with information on operators and scalars. - * - * We seek to identify pseudo-structural characters. Anything that is inside - * a string must be omitted (hence & ~_string.string_tail()). - * Otherwise, pseudo-structural characters come in two forms. - * 1. We have the structural characters ([,],{,},:, comma). The - * term 'structural character' is from the JSON RFC. 
- * 2. We have the 'scalar pseudo-structural characters'. - * Scalars are quotes, and any character except structural characters and - * white space. - * - * To identify the scalar pseudo-structural characters, we must look at what - * comes - * before them: it must be a space, a quote or a structural characters. - * Starting with simdjson v0.3, we identify them by - * negation: we identify everything that is followed by a non-quote scalar, - * and we negate that. Whatever remains must be a 'scalar pseudo-structural - * character'. - */ -struct json_block { - public: - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_block( - json_string_block &&string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(std::move(string)), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - simdjson_really_inline json_block( - json_string_block string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(string), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - - /** - * The start of structurals. - * In simdjson prior to v0.3, these were called the pseudo-structural - *characters. - **/ - simdjson_really_inline uint64_t structural_start() const noexcept { - return potential_structural_start() & ~_string.string_tail(); - } - /** All JSON whitespace (i.e. 
not in a string) */ - simdjson_really_inline uint64_t whitespace() const noexcept { - return non_quote_outside_string(_characters.whitespace()); - } - - // Helpers - - /** Whether the given characters are inside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t non_quote_inside_string(uint64_t mask) const - noexcept { - return _string.non_quote_inside_string(mask); - } - /** Whether the given characters are outside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const noexcept { - return _string.non_quote_outside_string(mask); - } - - // string and escape characters - json_string_block _string; - // whitespace, structural characters ('operators'), scalars - json_character_block _characters; - // whether the previous character was a scalar - uint64_t _follows_potential_nonquote_scalar; - - private: - // Potential structurals (i.e. disregarding strings) - - /** - * structural elements ([,],{,},:, comma) plus scalar starts like 123, true - *and "abc". - * They may reside inside a string. - **/ - simdjson_really_inline uint64_t potential_structural_start() const - noexcept { - return _characters.op() | potential_scalar_start(); - } - /** - * The start of non-operator runs, like 123, true and "abc". - * It main reside inside a string. - **/ - simdjson_really_inline uint64_t potential_scalar_start() const noexcept { - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // Whenever it is preceded by something that is not a structural element - // ({,},[,],:, ") nor a white-space - // then we know that it is irrelevant structurally. - return _characters.scalar() & ~follows_potential_scalar(); - } - /** - * Whether the given character is immediately after a non-operator like 123, - * true. - * The characters following a quote are not included. 
- */ - simdjson_really_inline uint64_t follows_potential_scalar() const noexcept { - // _follows_potential_nonquote_scalar: is defined as marking any - // character that follows a character - // that is not a structural element ({,},[,],:, comma) nor a quote (") - // and that is not a - // white space. - // It is understood that within quoted region, anything at all could be - // marked (irrelevant). - return _follows_potential_nonquote_scalar; - } -}; - -/** - * Scans JSON for important bits: structural characters or 'operators', strings, - * and scalars. - * - * The scanner starts by calculating two distinct things: - * - string characters (taking \" into account) - * - structural characters or 'operators' ([]{},:, comma) - * and scalars (runs of non-operators like 123, true and "abc") - * - * To minimize data dependency (a key component of the scanner's speed), it - * finds these in parallel: - * in particular, the operator/scalar bit will find plenty of things that are - * actually part of - * strings. When we're done, json_block will fuse the two together by masking - * out tokens that are - * part of a string. - */ -class json_scanner { - public: - json_scanner() {} - simdjson_really_inline json_block next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Whether the last character of the previous iteration is part of a scalar - // token - // (anything except whitespace or a structural character/'operator'). - uint64_t prev_scalar = 0ULL; - json_string_scanner string_scanner{}; -}; - - -// -// Check if the current character immediately follows a matching character. 
-// -// For example, this checks for quotes with backslashes in front of them: -// -// const uint64_t backslashed_quote = in.eq('"') & -// immediately_follows(in.eq('\'), prev_backslash); -// -simdjson_really_inline uint64_t follows(const uint64_t match, - uint64_t &overflow) { - const uint64_t result = match << 1 | overflow; - overflow = match >> 63; - return result; -} - -simdjson_really_inline json_block -json_scanner::next(const simd::simd8x64 &in) { - json_string_block strings = string_scanner.next(in); - // identifies the white-space and the structural characters - json_character_block characters = json_character_block::classify(in); - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // We want follows_scalar to mark anything that follows a non-quote scalar - // (so letters and numbers). - // - // A terminal quote should either be followed by a structural character - // (comma, brace, bracket, colon) - // or nothing. However, we still want ' "a string"true ' to mark the 't' of - // 'true' as a potential - // pseudo-structural character just like we would if we had ' "a string" - // true '; otherwise we - // may need to add an extra check when parsing strings. - // - // Performance: there are many ways to skin this cat. - const uint64_t nonquote_scalar = characters.scalar() & ~strings.quote(); - uint64_t follows_nonquote_scalar = follows(nonquote_scalar, prev_scalar); - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_block(strings, // strings is a function-local object so either - // it moves or the copy is elided. 
- characters, - follows_nonquote_scalar); -} - -simdjson_really_inline error_code json_scanner::finish() { - return string_scanner.finish(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/json_scanner.h */ -/* begin file src/generic/stage1/json_minifier.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -class json_minifier { - public: - template - static error_code minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept; - - private: - simdjson_really_inline json_minifier(uint8_t *_dst) : dst{_dst} {} - template - simdjson_really_inline void step( - const uint8_t *block_buf, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block); - simdjson_really_inline error_code finish(uint8_t *dst_start, - size_t &dst_len); - json_scanner scanner{}; - uint8_t *dst; -}; - -simdjson_really_inline void json_minifier::next( - const simd::simd8x64 &in, const json_block &block) { - uint64_t mask = block.whitespace(); - dst += in.compress(mask, dst); -} - -simdjson_really_inline error_code json_minifier::finish(uint8_t *dst_start, - size_t &dst_len) { - error_code error = scanner.finish(); - if (error) { - dst_len = 0; - return error; - } - dst_len = dst - dst_start; - return SUCCESS; -} - -template <> -simdjson_really_inline void json_minifier::step<128>( - const uint8_t *block_buf, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - simd::simd8x64 in_2(block_buf + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1); 
- this->next(in_2, block_2); - reader.advance(); -} - -template <> -simdjson_really_inline void json_minifier::step<64>( - const uint8_t *block_buf, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - json_block block_1 = scanner.next(in_1); - this->next(block_buf, block_1); - reader.advance(); -} - -template -error_code json_minifier::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept { - buf_block_reader reader(buf, len); - json_minifier minifier(dst); - - // Index the first n-1 blocks - while (reader.has_full_block()) { - minifier.step(reader.full_block(), reader); - } - - // Index the last (remainder) block, padded with spaces - uint8_t block[STEP_SIZE]; - size_t remaining_bytes = reader.get_remainder(block); - if (remaining_bytes > 0) { - // We do not want to write directly to the output stream. Rather, we - // write - // to a local buffer (for safety). - uint8_t out_block[STEP_SIZE]; - uint8_t *const guarded_dst{minifier.dst}; - minifier.dst = out_block; - minifier.step(block, reader); - size_t to_write = minifier.dst - out_block; - // In some cases, we could be enticed to consider the padded spaces - // as part of the string. This is fine as long as we do not write more - // than we consumed. - if (to_write > remaining_bytes) { - to_write = remaining_bytes; - } - memcpy(guarded_dst, out_block, to_write); - minifier.dst = guarded_dst + to_write; - } - return minifier.finish(dst, dst_len); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/json_minifier.h */ -/* begin file src/generic/stage1/find_next_document_index.h */ -namespace simdjson { -namespace arm64 { -namespace { - -/** - * This algorithm is used to quickly identify the last structural position that - * makes up a complete document. 
- * - * It does this by going backwards and finding the last *document boundary* (a - * place where one value follows another without a comma between them). If the - * last document (the characters after the boundary) has an equal number of - * start and end brackets, it is considered complete. - * - * Simply put, we iterate over the structural characters, starting from - * the end. We consider that we found the end of a JSON document when the - * first element of the pair is NOT one of these characters: '{' '[' ':' ',' - * and when the second element is NOT one of these characters: '}' ']' ':' ','. - * - * This simple comparison works most of the time, but it does not cover cases - * where the batch's structural indexes contain a perfect amount of documents. - * In such a case, we do not have access to the structural index which follows - * the last document, therefore, we do not have access to the second element in - * the pair, and that means we cannot identify the last document. To fix this - * issue, we keep a count of the open and closed curly/square braces we found - * while searching for the pair. When we find a pair AND the count of open and - * closed curly/square braces is the same, we know that we just passed a - * complete document, therefore the last json buffer location is the end of the - * batch. 
- */ -simdjson_really_inline uint32_t -find_next_document_index(dom_parser_implementation &parser) { - // Variant: do not count separately, just figure out depth - if (parser.n_structural_indexes == 0) { - return 0; - } - auto arr_cnt = 0; - auto obj_cnt = 0; - for (auto i = parser.n_structural_indexes - 1; i > 0; i--) { - auto idxb = parser.structural_indexes[i]; - switch (parser.buf[idxb]) { - case ':': - case ',': - continue; - case '}': - obj_cnt--; - continue; - case ']': - arr_cnt--; - continue; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - auto idxa = parser.structural_indexes[i - 1]; - switch (parser.buf[idxa]) { - case '{': - case '[': - case ':': - case ',': - continue; - } - // Last document is complete, so the next document will appear after! - if (!arr_cnt && !obj_cnt) { - return parser.n_structural_indexes; - } - // Last document is incomplete; mark the document at i + 1 as the next - // one - return i; - } - // If we made it to the end, we want to finish counting to see if we have a - // full document. - switch (parser.buf[parser.structural_indexes[0]]) { - case '}': - obj_cnt--; - break; - case ']': - arr_cnt--; - break; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - if (!arr_cnt && !obj_cnt) { - // We have a complete document. 
- return parser.n_structural_indexes; - } - return 0; -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/find_next_document_index.h */ - -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -class bit_indexer { - public: - uint32_t *tail; - - simdjson_really_inline bit_indexer(uint32_t *index_buf) : tail(index_buf) {} - - // flatten out values in 'bits' assuming that they are are to have values of - // idx - // plus their position in the bitvector, and store these indexes at - // base_ptr[base] incrementing base as we go - // will potentially store extra values beyond end of valid bits, so base_ptr - // needs to be large enough to handle this - simdjson_really_inline void write(uint32_t idx, uint64_t bits) { - // In some instances, the next branch is expensive because it is - // mispredicted. - // Unfortunately, in other cases, - // it helps tremendously. - if (bits == 0) return; -#if defined(SIMDJSON_PREFER_REVERSE_BITS) - /** - * ARM lacks a fast trailing zero instruction, but it has a fast - * bit reversal instruction and a fast leading zero instruction. - * Thus it may be profitable to reverse the bits (once) and then - * to rely on a sequence of instructions that call the leading - * zero instruction. - * - * Performance notes: - * The chosen routine is not optimal in terms of data dependency - * since zero_leading_bit might require two instructions. However, - * it tends to minimize the total number of instructions which is - * beneficial. - */ - - uint64_t rev_bits = reverse_bits(bits); - int cnt = static_cast(count_ones(bits)); - int i = 0; - // Do the first 8 all together - for (; i < 8; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). 
- if (simdjson_unlikely(cnt > 8)) { - i = 8; - for (; i < 16; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. - if (simdjson_unlikely(cnt > 16)) { - i = 16; - while (rev_bits != 0) { - int lz = leading_zeroes(rev_bits); - this->tail[i++] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - } - } - this->tail += cnt; -#else // SIMDJSON_PREFER_REVERSE_BITS - /** - * Under recent x64 systems, we often have both a fast trailing zero - * instruction and a fast 'clear-lower-bit' instruction so the following - * algorithm can be competitive. - */ - - int cnt = static_cast(count_ones(bits)); - // Do the first 8 all together - for (int i = 0; i < 8; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). - if (simdjson_unlikely(cnt > 8)) { - for (int i = 8; i < 16; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. 
- if (simdjson_unlikely(cnt > 16)) { - int i = 16; - do { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - i++; - } while (i < cnt); - } - } - - this->tail += cnt; -#endif - } -}; - -class json_structural_indexer { - public: - /** - * Find the important bits of JSON in a 128-byte chunk, and add them to - * structural_indexes. - * - * @param partial Setting the partial parameter to true allows the - * find_structural_bits to - * tolerate unclosed strings. The caller should still ensure that the - * input is valid UTF-8. If - * you are processing substrings, you may want to call on a function like - * trimmed_length_safe_utf8. - */ - template - static error_code index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept; - - private: - simdjson_really_inline json_structural_indexer( - uint32_t *structural_indexes); - template - simdjson_really_inline void step( - const uint8_t *block, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block, - size_t idx); - simdjson_really_inline error_code finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial); - - json_scanner scanner{}; - utf8_checker checker{}; - bit_indexer indexer; - uint64_t prev_structurals = 0; - uint64_t unescaped_chars_error = 0; -}; - -simdjson_really_inline json_structural_indexer::json_structural_indexer( - uint32_t *structural_indexes) - : indexer{structural_indexes} {} - -// Skip the last character if it is partial -simdjson_really_inline size_t trim_partial_utf8(const uint8_t *buf, - size_t len) { - if (simdjson_unlikely(len < 3)) { - switch (len) { - case 2: - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 2 bytes left - return len; - case 1: - if (buf[len - 
1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - return len; - case 0: - return len; - } - } - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 1 byte left - if (buf[len - 3] >= 0b11110000) { - return len - 3; - } // 4-byte characters with only 3 bytes left - return len; -} - -// -// PERF NOTES: -// We pipe 2 inputs through these stages: -// 1. Load JSON into registers. This takes a long time and is highly -// parallelizable, so we load -// 2 inputs' worth at once so that by the time step 2 is looking for them -// input, it's available. -// 2. Scan the JSON for critical data: strings, scalars and operators. This is -// the critical path. -// The output of step 1 depends entirely on this information. These functions -// don't quite use -// up enough CPU: the second half of the functions is highly serial, only -// using 1 execution core -// at a time. The second input's scans has some dependency on the first ones -// finishing it, but -// they can make a lot of progress before they need that information. -// 3. Step 1 doesn't use enough capacity, so we run some extra stuff while we're -// waiting for that -// to finish: utf-8 checks and generating the output from the last iteration. -// -// The reason we run 2 inputs at a time, is steps 2 and 3 are *still* not enough -// to soak up all -// available capacity with just one input. Running 2 at a time seems to give the -// CPU a good enough -// workout. -// -template -error_code json_structural_indexer::index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept { - if (simdjson_unlikely(len > parser.capacity())) { - return CAPACITY; - } - // We guard the rest of the code so that we can assume that len > 0 - // throughout. 
- if (len == 0) { - return EMPTY; - } - if (is_streaming(partial)) { - len = trim_partial_utf8(buf, len); - // If you end up with an empty window after trimming - // the partial UTF-8 bytes, then chances are good that you - // have an UTF-8 formatting error. - if (len == 0) { - return UTF8_ERROR; - } - } - buf_block_reader reader(buf, len); - json_structural_indexer indexer(parser.structural_indexes.get()); - - // Read all but the last block - while (reader.has_full_block()) { - indexer.step(reader.full_block(), reader); - } - // Take care of the last block (will always be there unless file is empty - // which is - // not supposed to happen.) - uint8_t block[STEP_SIZE]; - if (simdjson_unlikely(reader.get_remainder(block) == 0)) { - return UNEXPECTED_ERROR; - } - indexer.step(block, reader); - return indexer.finish(parser, reader.block_index(), len, partial); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<128>( - const uint8_t *block, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block); - simd::simd8x64 in_2(block + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1, reader.block_index()); - this->next(in_2, block_2, reader.block_index() + 64); - reader.advance(); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<64>( - const uint8_t *block, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block); - json_block block_1 = scanner.next(in_1); - this->next(in_1, block_1, reader.block_index()); - reader.advance(); -} - -simdjson_really_inline void json_structural_indexer::next( - const simd::simd8x64 &in, const json_block &block, size_t idx) { - uint64_t unescaped = in.lteq(0x1F); - checker.check_next_input(in); - indexer.write(uint32_t(idx - 64), prev_structurals); // Output *last* - // iteration's - // structurals to the - // parser - prev_structurals = block.structural_start(); - unescaped_chars_error |= 
block.non_quote_inside_string(unescaped); -} - -simdjson_really_inline error_code -json_structural_indexer::finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial) { - // Write out the final iteration's structurals - indexer.write(uint32_t(idx - 64), prev_structurals); - error_code error = scanner.finish(); - // We deliberately break down the next expression so that it is - // human readable. - const bool should_we_exit = - is_streaming(partial) - ? ((error != SUCCESS) && - (error != - UNCLOSED_STRING)) // when partial we tolerate UNCLOSED_STRING - : (error != SUCCESS); // if partial is false, we must have SUCCESS - const bool have_unclosed_string = (error == UNCLOSED_STRING); - if (simdjson_unlikely(should_we_exit)) { - return error; - } - - if (unescaped_chars_error) { - return UNESCAPED_CHARS; - } - parser.n_structural_indexes = - uint32_t(indexer.tail - parser.structural_indexes.get()); - /*** - * The On Demand API requires special padding. - * - * This is related to https://github.com/simdjson/simdjson/issues/906 - * Basically, we want to make sure that if the parsing continues beyond the - *last (valid) - * structural character, it quickly stops. - * Only three structural characters can be repeated without triggering an - *error in JSON: [,] and }. - * We repeat the padding character (at 'len'). We don't know what it is, but - *if the parsing - * continues, then it must be [,] or }. - * Suppose it is ] or }. We backtrack to the first character, what could it - *be that would - * not trigger an error? It could be ] or } but no, because you can't start - *a document that way. - * It can't be a comma, a colon or any simple value. So the only way we - *could continue is - * if the repeated character is [. But if so, the document must start with - *[. But if the document - * starts with [, it should end with ]. If we enforce that rule, then we - *would get - * ][[ which is invalid. 
- * - * This is illustrated with the test array_iterate_unclosed_error() on the - *following input: - * R"({ "a": [,,)" - **/ - parser.structural_indexes[parser.n_structural_indexes] = - uint32_t(len); // used later in partial == stage1_mode::streaming_final - parser.structural_indexes[parser.n_structural_indexes + 1] = uint32_t(len); - parser.structural_indexes[parser.n_structural_indexes + 2] = 0; - parser.next_structural_index = 0; - // a valid JSON file cannot have zero structural indexes - we should have - // found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return EMPTY; - } - if (simdjson_unlikely( - parser.structural_indexes[parser.n_structural_indexes - 1] > len)) { - return UNEXPECTED_ERROR; - } - if (partial == stage1_mode::streaming_partial) { - // If we have an unclosed string, then the last structural - // will be the quote and we want to make sure to omit it. - if (have_unclosed_string) { - parser.n_structural_indexes--; - // a valid JSON file cannot have zero structural indexes - we should - // have found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return CAPACITY; - } - } - // We truncate the input to the end of the last complete document (or - // zero). - auto new_structural_indexes = find_next_document_index(parser); - if (new_structural_indexes == 0 && parser.n_structural_indexes > 0) { - if (parser.structural_indexes[0] == 0) { - // If the buffer is partial and we started at index 0 but the - // document is - // incomplete, it's too big to parse. - return CAPACITY; - } else { - // It is possible that the document could be parsed, we just had - // a lot - // of white space. 
- parser.n_structural_indexes = 0; - return EMPTY; - } - } - - parser.n_structural_indexes = new_structural_indexes; - } else if (partial == stage1_mode::streaming_final) { - if (have_unclosed_string) { - parser.n_structural_indexes--; - } - // We truncate the input to the end of the last complete document (or - // zero). - // Because partial == stage1_mode::streaming_final, it means that we may - // silently ignore trailing garbage. Though it sounds bad, we do it - // deliberately because many people who have streams of JSON documents - // will truncate them for processing. E.g., imagine that you are - // uncompressing - // the data from a size file or receiving it in chunks from the network. - // You - // may not know where exactly the last document will be. Meanwhile the - // document_stream instances allow people to know the JSON documents - // they are - // parsing (see the iterator.source() method). - parser.n_structural_indexes = find_next_document_index(parser); - // We store the initial n_structural_indexes so that the client can see - // whether we used truncation. If initial_n_structural_indexes == - // parser.n_structural_indexes, - // then this will query - // parser.structural_indexes[parser.n_structural_indexes] which is len, - // otherwise, it will copy some prior index. - parser.structural_indexes[parser.n_structural_indexes + 1] = - parser.structural_indexes[parser.n_structural_indexes]; - // This next line is critical, do not change it unless you understand - // what you are - // doing. - parser.structural_indexes[parser.n_structural_indexes] = uint32_t(len); - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - // We tolerate an unclosed string at the very end of the stream. - // Indeed, users - // often load their data in bulk without being careful and they want - // us to ignore - // the trailing garbage. 
- return EMPTY; - } - } - checker.check_eof(); - return checker.errors(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/json_structural_indexer.h */ -/* begin file src/generic/stage1/utf8_validator.h */ -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -/** - * Validates that the string is actual UTF-8. - */ -template -bool generic_validate_utf8(const uint8_t *input, size_t length) { - checker c{}; - buf_block_reader<64> reader(input, length); - while (reader.has_full_block()) { - simd::simd8x64 in(reader.full_block()); - c.check_next_input(in); - reader.advance(); - } - uint8_t block[64]{}; - reader.get_remainder(block); - simd::simd8x64 in(block); - c.check_next_input(in); - reader.advance(); - c.check_eof(); - return c.errors() == error_code::SUCCESS; -} - -bool generic_validate_utf8(const char *input, size_t length) { - return generic_validate_utf8( - reinterpret_cast(input), length); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage1/utf8_validator.h */ - -// -// Stage 2 -// - -/* begin file src/generic/stage2/tape_builder.h */ -/* begin file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/logger.h */ -// This is for an internal-only stage 2 specific logger. -// Set LOG_ENABLED = true to log what stage 2 is doing! 
-namespace simdjson { -namespace arm64 { -namespace { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static constexpr const int LOG_INDEX_LEN = 5; - -static int log_depth; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static simdjson_really_inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -// Print the header and set up log_start -static simdjson_really_inline void log_start() { - if (LOG_ENABLED) { - log_depth = 0; - printf("\n"); - printf("| %-*s | %-*s | %-*s | %-*s | Detail |\n", - LOG_EVENT_LEN, - "Event", - LOG_BUFFER_LEN, - "Buffer", - LOG_SMALL_BUFFER_LEN, - "Next", - 5, - "Next#"); - printf("|%.*s|%.*s|%.*s|%.*s|--------|\n", - LOG_EVENT_LEN + 2, - DASHES, - LOG_BUFFER_LEN + 2, - DASHES, - LOG_SMALL_BUFFER_LEN + 2, - DASHES, - 5 + 2, - DASHES); - } -} - -simdjson_unused static simdjson_really_inline void log_string( - const char *message) { - if (LOG_ENABLED) { - printf("%s\n", message); - } -} - -// Logs a single line from the stage 2 DOM parser -template -static simdjson_really_inline void log_line(S &structurals, - const char *title_prefix, - const char *title, - const char *detail) { - if (LOG_ENABLED) { - printf("| %*s%s%-*s ", - log_depth * 2, - "", - title_prefix, - LOG_EVENT_LEN - log_depth * 2 - int(strlen(title_prefix)), - title); - auto current_index = 
structurals.at_beginning() - ? nullptr - : structurals.next_structural - 1; - auto next_index = structurals.next_structural; - auto current = current_index ? &structurals.buf[*current_index] - : reinterpret_cast( - " " - " "); - auto next = &structurals.buf[*next_index]; - { - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current[i])); - } - printf(" "); - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next[i])); - } - printf(" "); - } - if (current_index) { - printf("| %*u ", LOG_INDEX_LEN, *current_index); - } else { - printf("| %-*s ", LOG_INDEX_LEN, ""); - } - // printf("| %*u ", LOG_INDEX_LEN, structurals.next_tape_index()); - printf("| %-s ", detail); - printf("|\n"); - } -} - -} // namespace logger -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage2/logger.h */ - -namespace simdjson { -namespace arm64 { -namespace { -namespace stage2 { - -class json_iterator { - public: - const uint8_t *const buf; - uint32_t *next_structural; - dom_parser_implementation &dom_parser; - uint32_t depth{0}; - - /** - * Walk the JSON document. - * - * The visitor receives callbacks when values are encountered. All callbacks - * pass the iterator as - * the first parameter; some callbacks have other parameters as well: - * - * - visit_document_start() - at the beginning. - * - visit_document_end() - at the end (if things were successful). - * - * - visit_array_start() - at the start `[` of a non-empty array. - * - visit_array_end() - at the end `]` of a non-empty array. 
- * - visit_empty_array() - when an empty array is encountered. - * - * - visit_object_end() - at the start `]` of a non-empty object. - * - visit_object_start() - at the end `]` of a non-empty object. - * - visit_empty_object() - when an empty object is encountered. - * - visit_key(const uint8_t *key) - when a key in an object field is - * encountered. key is - * guaranteed to point at the first quote - * of the string (`"key"`). - * - visit_primitive(const uint8_t *value) - when a value is a string, - * number, boolean or null. - * - visit_root_primitive(iter, uint8_t *value) - when the top-level value - * is a string, number, boolean or null. - * - * - increment_count(iter) - each time a value is found in an array or - * object. - */ - template - simdjson_warn_unused simdjson_really_inline error_code - walk_document(V &visitor) noexcept; - - /** - * Create an iterator capable of walking a JSON document. - * - * The document must have already passed through stage 1. - */ - simdjson_really_inline json_iterator(dom_parser_implementation &_dom_parser, - size_t start_structural_index); - - /** - * Look at the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *peek() const noexcept; - /** - * Advance to the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *advance() noexcept; - /** - * Get the remaining length of the document, from the start of the current - * token. - */ - simdjson_really_inline size_t remaining_len() const noexcept; - /** - * Check if we are at the end of the document. - * - * If this is true, there are no more tokens. 
- */ - simdjson_really_inline bool at_eof() const noexcept; - /** - * Check if we are at the beginning of the document. - */ - simdjson_really_inline bool at_beginning() const noexcept; - simdjson_really_inline uint8_t last_structural() const noexcept; - - /** - * Log that a value has been found. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_value(const char *type) const noexcept; - /** - * Log the start of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_start_value(const char *type) const - noexcept; - /** - * Log the end of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_end_value(const char *type) const noexcept; - /** - * Log an error. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_error(const char *error) const noexcept; - - template - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(V &visitor, const uint8_t *value) noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(V &visitor, const uint8_t *value) noexcept; -}; - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::walk_document(V &visitor) noexcept { - logger::log_start(); - - // - // Start the document - // - if (at_eof()) { - return EMPTY; - } - log_start_value("document"); - SIMDJSON_TRY(visitor.visit_document_start(*this)); - - // - // Read first value - // - { - auto value = advance(); - - // Make sure the outer object or array is closed before continuing; - // otherwise, there are ways we - // could get into memory corruption. 
See - // https://github.com/simdjson/simdjson/issues/906 - if (!STREAMING) { - switch (*value) { - case '{': - if (last_structural() != '}') { - log_value("starting brace unmatched"); - return TAPE_ERROR; - }; - break; - case '[': - if (last_structural() != ']') { - log_value("starting bracket unmatched"); - return TAPE_ERROR; - }; - break; - } - } - - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_root_primitive(*this, value)); - break; - } - } - goto document_end; - -// -// Object parser states -// -object_begin: - log_start_value("object"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = false; - SIMDJSON_TRY(visitor.visit_object_start(*this)); - - { - auto key = advance(); - if (*key != '"') { - log_error("Object does not start with a key"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.increment_count(*this)); - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - -object_field: - if (simdjson_unlikely(*advance() != ':')) { - log_error("Missing colon after key in object"); - return TAPE_ERROR; - } - { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } - } - -object_continue: - switch (*advance()) { - case ',': - 
SIMDJSON_TRY(visitor.increment_count(*this)); - { - auto key = advance(); - if (simdjson_unlikely(*key != '"')) { - log_error( - "Key string missing at beginning of field in object"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - goto object_field; - case '}': - log_end_value("object"); - SIMDJSON_TRY(visitor.visit_object_end(*this)); - goto scope_end; - default: - log_error("No comma between object fields"); - return TAPE_ERROR; - } - -scope_end: - depth--; - if (depth == 0) { - goto document_end; - } - if (dom_parser.is_array[depth]) { - goto array_continue; - } - goto object_continue; - -// -// Array parser states -// -array_begin: - log_start_value("array"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = true; - SIMDJSON_TRY(visitor.visit_array_start(*this)); - SIMDJSON_TRY(visitor.increment_count(*this)); - -array_value : { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } -} - -array_continue: - switch (*advance()) { - case ',': - SIMDJSON_TRY(visitor.increment_count(*this)); - goto array_value; - case ']': - log_end_value("array"); - SIMDJSON_TRY(visitor.visit_array_end(*this)); - goto scope_end; - default: - log_error("Missing comma between array values"); - return TAPE_ERROR; - } - -document_end: - log_end_value("document"); - SIMDJSON_TRY(visitor.visit_document_end(*this)); - - dom_parser.next_structural_index = - uint32_t(next_structural - &dom_parser.structural_indexes[0]); - - // If we didn't make it to the end, it's an error - if 
(!STREAMING && - dom_parser.next_structural_index != dom_parser.n_structural_indexes) { - log_error( - "More than one JSON value at the root of the document, or extra " - "characters at the end of the JSON!"); - return TAPE_ERROR; - } - - return SUCCESS; - -} // walk_document() - -simdjson_really_inline json_iterator::json_iterator( - dom_parser_implementation &_dom_parser, size_t start_structural_index) - : buf{_dom_parser.buf}, - next_structural{&_dom_parser.structural_indexes[start_structural_index]}, - dom_parser{_dom_parser} {} - -simdjson_really_inline const uint8_t *json_iterator::peek() const noexcept { - return &buf[*(next_structural)]; -} -simdjson_really_inline const uint8_t *json_iterator::advance() noexcept { - return &buf[*(next_structural++)]; -} -simdjson_really_inline size_t json_iterator::remaining_len() const noexcept { - return dom_parser.len - *(next_structural - 1); -} - -simdjson_really_inline bool json_iterator::at_eof() const noexcept { - return next_structural == - &dom_parser.structural_indexes[dom_parser.n_structural_indexes]; -} -simdjson_really_inline bool json_iterator::at_beginning() const noexcept { - return next_structural == dom_parser.structural_indexes.get(); -} -simdjson_really_inline uint8_t json_iterator::last_structural() const noexcept { - return buf[dom_parser - .structural_indexes[dom_parser.n_structural_indexes - 1]]; -} - -simdjson_really_inline void json_iterator::log_value(const char *type) const - noexcept { - logger::log_line(*this, "", type, ""); -} - -simdjson_really_inline void json_iterator::log_start_value( - const char *type) const noexcept { - logger::log_line(*this, "+", type, ""); - if (logger::LOG_ENABLED) { - logger::log_depth++; - } -} - -simdjson_really_inline void json_iterator::log_end_value(const char *type) const - noexcept { - if (logger::LOG_ENABLED) { - logger::log_depth--; - } - logger::log_line(*this, "-", type, ""); -} - -simdjson_really_inline void json_iterator::log_error(const char *error) 
const - noexcept { - logger::log_line(*this, "", "ERROR", error); -} - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_root_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_root_string(*this, value); - case 't': - return visitor.visit_root_true_atom(*this, value); - case 'f': - return visitor.visit_root_false_atom(*this, value); - case 'n': - return visitor.visit_root_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_root_number(*this, value); - default: - log_error("Document starts with a non-value character"); - return TAPE_ERROR; - } -} -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_string(*this, value); - case 't': - return visitor.visit_true_atom(*this, value); - case 'f': - return visitor.visit_false_atom(*this, value); - case 'n': - return visitor.visit_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_number(*this, value); - default: - log_error("Non-value found when value was expected!"); - return TAPE_ERROR; - } -} - -} // namespace stage2 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/tape_writer.h */ -namespace simdjson { -namespace arm64 { -namespace { -namespace stage2 { - -struct tape_writer { - /** The next place to write to tape */ - uint64_t *next_tape_loc; - - /** Write a signed 64-bit value to tape. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - - /** Write an unsigned 64-bit value to tape. 
*/ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - - /** Write a double value to tape. */ - simdjson_really_inline void append_double(double value) noexcept; - - /** - * Append a tape entry (an 8-bit type,and 56 bits worth of value). - */ - simdjson_really_inline void append(uint64_t val, - internal::tape_type t) noexcept; - - /** - * Skip the current tape entry without writing. - * - * Used to skip the start of the container, since we'll come back later to - * fill it in when the - * container ends. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Skip the number of tape entries necessary to write a large u64 or i64. - */ - simdjson_really_inline void skip_large_integer() noexcept; - - /** - * Skip the number of tape entries necessary to write a double. - */ - simdjson_really_inline void skip_double() noexcept; - - /** - * Write a value to a known location on tape. - * - * Used to go back and write out the start of a container after the - * container ends. - */ - simdjson_really_inline static void write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept; - - private: - /** - * Append both the tape entry, and a supplementary value following it. Used - * for types that need - * all 64 bits, such as double and uint64_t. - */ - template - simdjson_really_inline void append2(uint64_t val, - T val2, - internal::tape_type t) noexcept; -}; // struct number_writer - -simdjson_really_inline void tape_writer::append_s64(int64_t value) noexcept { - append2(0, value, internal::tape_type::INT64); -} - -simdjson_really_inline void tape_writer::append_u64(uint64_t value) noexcept { - append(0, internal::tape_type::UINT64); - *next_tape_loc = value; - next_tape_loc++; -} - -/** Write a double value to tape. 
*/ -simdjson_really_inline void tape_writer::append_double(double value) noexcept { - append2(0, value, internal::tape_type::DOUBLE); -} - -simdjson_really_inline void tape_writer::skip() noexcept { next_tape_loc++; } - -simdjson_really_inline void tape_writer::skip_large_integer() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::skip_double() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::append( - uint64_t val, internal::tape_type t) noexcept { - *next_tape_loc = val | ((uint64_t(char(t))) << 56); - next_tape_loc++; -} - -template -simdjson_really_inline void tape_writer::append2( - uint64_t val, T val2, internal::tape_type t) noexcept { - append(val, t); - static_assert(sizeof(val2) == sizeof(*next_tape_loc), - "Type is not 64 bits!"); - memcpy(next_tape_loc, &val2, sizeof(val2)); - next_tape_loc++; -} - -simdjson_really_inline void tape_writer::write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept { - tape_loc = val | ((uint64_t(char(t))) << 56); -} - -} // namespace stage2 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage2/tape_writer.h */ - -namespace simdjson { -namespace arm64 { -namespace { -namespace stage2 { - -struct tape_builder { - template - simdjson_warn_unused static simdjson_really_inline error_code - parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept; - - /** Called when a non-empty document starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_start(json_iterator &iter) noexcept; - /** Called when a non-empty document ends without error. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_end(json_iterator &iter) noexcept; - - /** Called when a non-empty array starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_start(json_iterator &iter) noexcept; - /** Called when a non-empty array ends. 
*/ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_end(json_iterator &iter) noexcept; - /** Called when an empty array is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_array(json_iterator &iter) noexcept; - - /** Called when a non-empty object starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_start(json_iterator &iter) noexcept; - /** - * Called when a key in a field is encountered. - * - * primitive, visit_object_start, visit_empty_object, visit_array_start, or - * visit_empty_array - * will be called after this with the field value. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_key(json_iterator &iter, const uint8_t *key) noexcept; - /** Called when a non-empty object ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_end(json_iterator &iter) noexcept; - /** Called when an empty object is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_object(json_iterator &iter) noexcept; - - /** - * Called when a string, number, boolean or null is found. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(json_iterator &iter, const uint8_t *value) noexcept; - /** - * Called when a string, number, boolean or null is found at the top level - * of a document (i.e. - * when there is no array or object and the entire document is a single - * string, number, boolean or - * null. - * - * This is separate from primitive() because simdjson's normal primitive - * parsing routines assume - * there is at least one more token after the value, which is only true in - * an array or object. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code visit_string( - json_iterator &iter, const uint8_t *value, bool key = false) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - visit_root_string(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - /** Called each time a new field or element in an array or object is found. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - increment_count(json_iterator &iter) noexcept; - - /** Next location to write to tape */ - tape_writer tape; - - private: - /** Next write location in the string buf for stage 2 parsing */ - uint8_t *current_string_buf_loc; - - simdjson_really_inline tape_builder(dom::document &doc) noexcept; - - simdjson_really_inline uint32_t next_tape_index(json_iterator &iter) const - noexcept; - simdjson_really_inline void start_container(json_iterator &iter) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_really_inline uint8_t *on_start_string( - json_iterator &iter) noexcept; - simdjson_really_inline void on_end_string(uint8_t *dst) noexcept; -}; // class tape_builder - -template -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept { - dom_parser.doc = &doc; - json_iterator iter(dom_parser, - STREAMING ? 
dom_parser.next_structural_index : 0); - tape_builder builder(doc); - return iter.walk_document(builder); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_root_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_object(json_iterator &iter) noexcept { - return empty_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_array(json_iterator &iter) noexcept { - return empty_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_document_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_end(json_iterator &iter) noexcept { - return end_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_end(json_iterator &iter) noexcept { - return end_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} -simdjson_warn_unused simdjson_really_inline error_code 
-tape_builder::visit_document_end(json_iterator &iter) noexcept { - constexpr uint32_t start_tape_index = 0; - tape.append(start_tape_index, internal::tape_type::ROOT); - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter), - internal::tape_type::ROOT); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_key(json_iterator &iter, const uint8_t *key) noexcept { - return visit_string(iter, key, true); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::increment_count(json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth] - .count++; // we have a key value pair in the object at - // parser.dom_parser.depth - 1 - return SUCCESS; -} - -simdjson_really_inline tape_builder::tape_builder(dom::document &doc) noexcept - : tape{doc.tape.get()}, - current_string_buf_loc{doc.string_buf.get()} {} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_string(json_iterator &iter, - const uint8_t *value, - bool key) noexcept { - iter.log_value(key ? "key" : "string"); - uint8_t *dst = on_start_string(iter); - dst = stringparsing::parse_string(value + 1, dst); - if (dst == nullptr) { - iter.log_error("Invalid escape in string"); - return STRING_ERROR; - } - on_end_string(dst); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_string(json_iterator &iter, - const uint8_t *value) noexcept { - return visit_string(iter, value); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_number(json_iterator &iter, const uint8_t *value) noexcept { - iter.log_value("number"); - return numberparsing::parse_number(value, tape); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_number(json_iterator &iter, - const uint8_t *value) noexcept { - // - // We need to make a copy to make sure that the string is space terminated. 
- // This is not about padding the input, which should already padded up - // to len + SIMDJSON_PADDING. However, we have no control at this stage - // on how the padding was done. What if the input string was padded with - // nulls? - // It is quite common for an input string to have an extra null character (C - // string). - // We do not want to allow 9\0 (where \0 is the null character) inside a - // JSON - // document, but the string "9\0" by itself is fine. So we make a copy and - // pad the input with spaces when we know that there is just one input - // element. - // This copy is relatively expensive, but it will almost never be called in - // practice unless you are in the strange scenario where you have many JSON - // documents made of single atoms. - // - std::unique_ptr copy( - new (std::nothrow) uint8_t[iter.remaining_len() + SIMDJSON_PADDING]); - if (copy.get() == nullptr) { - return MEMALLOC; - } - std::memcpy(copy.get(), value, iter.remaining_len()); - std::memset(copy.get() + iter.remaining_len(), ' ', SIMDJSON_PADDING); - error_code error = visit_number(iter, copy.get()); - return error; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value)) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value, iter.remaining_len())) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if 
(!atomparsing::is_valid_false_atom(value)) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if (!atomparsing::is_valid_false_atom(value, iter.remaining_len())) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value)) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value, iter.remaining_len())) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -// private: - -simdjson_really_inline uint32_t -tape_builder::next_tape_index(json_iterator &iter) const noexcept { - return uint32_t(tape.next_tape_loc - iter.dom_parser.doc->tape.get()); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - auto start_index = next_tape_index(iter); - tape.append(start_index + 2, start); - tape.append(start_index, end); - return SUCCESS; -} - -simdjson_really_inline void tape_builder::start_container( - json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth].tape_index = - next_tape_index(iter); - iter.dom_parser.open_containers[iter.depth].count = 0; - tape.skip(); // We don't actually *write* the start element until the end. 
-} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - // Write the ending tape element, pointing at the start location - const uint32_t start_tape_index = - iter.dom_parser.open_containers[iter.depth].tape_index; - tape.append(start_tape_index, end); - // Write the start tape element, pointing at the end location (and including - // count) - // count can overflow if it exceeds 24 bits... so we saturate - // the convention being that a cnt of 0xffffff or more is undetermined in - // value (>= 0xffffff). - const uint32_t count = iter.dom_parser.open_containers[iter.depth].count; - const uint32_t cntsat = count > 0xFFFFFF ? 0xFFFFFF : count; - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter) | (uint64_t(cntsat) << 32), - start); - return SUCCESS; -} - -simdjson_really_inline uint8_t *tape_builder::on_start_string( - json_iterator &iter) noexcept { - // we advance the point, accounting for the fact that we have a NULL - // termination - tape.append(current_string_buf_loc - iter.dom_parser.doc->string_buf.get(), - internal::tape_type::STRING); - return current_string_buf_loc + sizeof(uint32_t); -} - -simdjson_really_inline void tape_builder::on_end_string(uint8_t *dst) noexcept { - uint32_t str_length = - uint32_t(dst - (current_string_buf_loc + sizeof(uint32_t))); - // TODO check for overflow in case someone has a crazy string (>=4GB?) - // But only add the overflow check when the document itself exceeds 4GB - // Currently unneeded because we refuse to parse docs larger or equal to - // 4GB. - memcpy(current_string_buf_loc, &str_length, sizeof(uint32_t)); - // NULL termination is still handy if you expect all your strings to - // be NULL terminated? 
It comes at a small cost - *dst = 0; - current_string_buf_loc = dst + 1; -} - -} // namespace stage2 -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file src/generic/stage2/tape_builder.h */ - -// -// Implementation-specific overrides -// -namespace simdjson { -namespace arm64 { -namespace { -namespace stage1 { - -simdjson_really_inline uint64_t -json_string_scanner::find_escaped(uint64_t backslash) { - // On ARM, we don't short-circuit this if there are no backslashes, because - // the branch gives us no - // benefit and therefore makes things worse. - // if (!backslash) { uint64_t escaped = prev_escaped; prev_escaped = 0; - // return escaped; } - return find_escaped_branchless(backslash); -} - -} // namespace stage1 -} // unnamed namespace - -simdjson_warn_unused error_code implementation::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept { - return arm64::stage1::json_minifier::minify<64>(buf, len, dst, dst_len); -} - -simdjson_warn_unused error_code dom_parser_implementation::stage1( - const uint8_t *_buf, size_t _len, stage1_mode streaming) noexcept { - this->buf = _buf; - this->len = _len; - return arm64::stage1::json_structural_indexer::index<64>( - buf, len, *this, streaming); -} - -simdjson_warn_unused bool implementation::validate_utf8(const char *buf, - size_t len) const - noexcept { - return arm64::stage1::generic_validate_utf8(buf, len); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2_next(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code dom_parser_implementation::parse( - const uint8_t *_buf, size_t _len, dom::document &_doc) noexcept { - auto error = stage1(_buf, _len, stage1_mode::regular); - if (error) { - 
return error; - } - return stage2(_doc); -} - -} // namespace arm64 -} // namespace simdjson - -/* begin file include/simdjson/arm64/end.h */ -/* end file include/simdjson/arm64/end.h */ -/* end file src/arm64/dom_parser_implementation.cpp */ -#endif -#if SIMDJSON_IMPLEMENTATION_FALLBACK -/* begin file src/fallback/implementation.cpp */ -/* begin file include/simdjson/fallback/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "fallback" -// #define SIMDJSON_IMPLEMENTATION fallback -/* end file include/simdjson/fallback/begin.h */ - -namespace simdjson { -namespace fallback { - -simdjson_warn_unused error_code -implementation::create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const noexcept { - dst.reset(new (std::nothrow) dom_parser_implementation()); - if (!dst) { - return MEMALLOC; - } - if (auto err = dst->set_capacity(capacity)) return err; - if (auto err = dst->set_max_depth(max_depth)) return err; - return SUCCESS; -} - -} // namespace fallback -} // namespace simdjson - -/* begin file include/simdjson/fallback/end.h */ -/* end file include/simdjson/fallback/end.h */ -/* end file src/fallback/implementation.cpp */ -/* begin file src/fallback/dom_parser_implementation.cpp */ -/* begin file include/simdjson/fallback/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "fallback" -// #define SIMDJSON_IMPLEMENTATION fallback -/* end file include/simdjson/fallback/begin.h */ - -// -// Stage 1 -// -/* begin file src/generic/stage1/find_next_document_index.h */ -namespace simdjson { -namespace fallback { -namespace { - -/** - * This algorithm is used to quickly identify the last structural position that - * makes up a complete document. - * - * It does this by going backwards and finding the last *document boundary* (a - * place where one value follows another without a comma between them). 
If the - * last document (the characters after the boundary) has an equal number of - * start and end brackets, it is considered complete. - * - * Simply put, we iterate over the structural characters, starting from - * the end. We consider that we found the end of a JSON document when the - * first element of the pair is NOT one of these characters: '{' '[' ':' ',' - * and when the second element is NOT one of these characters: '}' ']' ':' ','. - * - * This simple comparison works most of the time, but it does not cover cases - * where the batch's structural indexes contain a perfect amount of documents. - * In such a case, we do not have access to the structural index which follows - * the last document, therefore, we do not have access to the second element in - * the pair, and that means we cannot identify the last document. To fix this - * issue, we keep a count of the open and closed curly/square braces we found - * while searching for the pair. When we find a pair AND the count of open and - * closed curly/square braces is the same, we know that we just passed a - * complete document, therefore the last json buffer location is the end of the - * batch. - */ -simdjson_really_inline uint32_t -find_next_document_index(dom_parser_implementation &parser) { - // Variant: do not count separately, just figure out depth - if (parser.n_structural_indexes == 0) { - return 0; - } - auto arr_cnt = 0; - auto obj_cnt = 0; - for (auto i = parser.n_structural_indexes - 1; i > 0; i--) { - auto idxb = parser.structural_indexes[i]; - switch (parser.buf[idxb]) { - case ':': - case ',': - continue; - case '}': - obj_cnt--; - continue; - case ']': - arr_cnt--; - continue; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - auto idxa = parser.structural_indexes[i - 1]; - switch (parser.buf[idxa]) { - case '{': - case '[': - case ':': - case ',': - continue; - } - // Last document is complete, so the next document will appear after! 
- if (!arr_cnt && !obj_cnt) { - return parser.n_structural_indexes; - } - // Last document is incomplete; mark the document at i + 1 as the next - // one - return i; - } - // If we made it to the end, we want to finish counting to see if we have a - // full document. - switch (parser.buf[parser.structural_indexes[0]]) { - case '}': - obj_cnt--; - break; - case ']': - arr_cnt--; - break; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - if (!arr_cnt && !obj_cnt) { - // We have a complete document. - return parser.n_structural_indexes; - } - return 0; -} - -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file src/generic/stage1/find_next_document_index.h */ - -namespace simdjson { -namespace fallback { -namespace { -namespace stage1 { - -class structural_scanner { - public: - simdjson_really_inline structural_scanner( - dom_parser_implementation &_parser, stage1_mode _partial) - : buf{_parser.buf}, - next_structural_index{_parser.structural_indexes.get()}, - parser{_parser}, - len{static_cast(_parser.len)}, - partial{_partial} {} - - simdjson_really_inline void add_structural() { - *next_structural_index = idx; - next_structural_index++; - } - - simdjson_really_inline bool is_continuation(uint8_t c) { - return (c & 0b11000000) == 0b10000000; - } - - simdjson_really_inline void validate_utf8_character() { - // Continuation - if (simdjson_unlikely((buf[idx] & 0b01000000) == 0)) { - // extra continuation - error = UTF8_ERROR; - idx++; - return; - } - - // 2-byte - if ((buf[idx] & 0b00100000) == 0) { - // missing continuation - if (simdjson_unlikely(idx + 1 > len || - !is_continuation(buf[idx + 1]))) { - if (idx + 1 > len && is_streaming(partial)) { - idx = len; - return; - } - error = UTF8_ERROR; - idx++; - return; - } - // overlong: 1100000_ 10______ - if (buf[idx] <= 0b11000001) { - error = UTF8_ERROR; - } - idx += 2; - return; - } - - // 3-byte - if ((buf[idx] & 0b00010000) == 0) { - // missing continuation - 
if (simdjson_unlikely(idx + 2 > len || - !is_continuation(buf[idx + 1]) || - !is_continuation(buf[idx + 2]))) { - if (idx + 2 > len && is_streaming(partial)) { - idx = len; - return; - } - error = UTF8_ERROR; - idx++; - return; - } - // overlong: 11100000 100_____ ________ - if (buf[idx] == 0b11100000 && buf[idx + 1] <= 0b10011111) { - error = UTF8_ERROR; - } - // surrogates: U+D800-U+DFFF 11101101 101_____ - if (buf[idx] == 0b11101101 && buf[idx + 1] >= 0b10100000) { - error = UTF8_ERROR; - } - idx += 3; - return; - } - - // 4-byte - // missing continuation - if (simdjson_unlikely(idx + 3 > len || !is_continuation(buf[idx + 1]) || - !is_continuation(buf[idx + 2]) || - !is_continuation(buf[idx + 3]))) { - if (idx + 2 > len && is_streaming(partial)) { - idx = len; - return; - } - error = UTF8_ERROR; - idx++; - return; - } - // overlong: 11110000 1000____ ________ ________ - if (buf[idx] == 0b11110000 && buf[idx + 1] <= 0b10001111) { - error = UTF8_ERROR; - } - // too large: > U+10FFFF: - // 11110100 (1001|101_)____ - // 1111(1___|011_|0101) 10______ - // also includes 5, 6, 7 and 8 byte characters: - // 11111___ - if (buf[idx] == 0b11110100 && buf[idx + 1] >= 0b10010000) { - error = UTF8_ERROR; - } - if (buf[idx] >= 0b11110101) { - error = UTF8_ERROR; - } - idx += 4; - } - - // Returns true if the string is unclosed. 
- simdjson_really_inline bool validate_string() { - idx++; // skip first quote - while (idx < len && buf[idx] != '"') { - if (buf[idx] == '\\') { - idx += 2; - } else if (simdjson_unlikely(buf[idx] & 0b10000000)) { - validate_utf8_character(); - } else { - if (buf[idx] < 0x20) { - error = UNESCAPED_CHARS; - } - idx++; - } - } - if (idx >= len) { - return true; - } - return false; - } - - simdjson_really_inline bool is_whitespace_or_operator(uint8_t c) { - switch (c) { - case '{': - case '}': - case '[': - case ']': - case ',': - case ':': - case ' ': - case '\r': - case '\n': - case '\t': - return true; - default: - return false; - } - } - - // - // Parse the entire input in STEP_SIZE-byte chunks. - // - simdjson_really_inline error_code scan() { - bool unclosed_string = false; - for (; idx < len; idx++) { - switch (buf[idx]) { - // String - case '"': - add_structural(); - unclosed_string |= validate_string(); - break; - // Operator - case '{': - case '}': - case '[': - case ']': - case ',': - case ':': - add_structural(); - break; - // Whitespace - case ' ': - case '\r': - case '\n': - case '\t': - break; - // Primitive or invalid character (invalid characters will be - // checked in stage 2) - default: - // Anything else, add the structural and go until we find - // the next one - add_structural(); - while (idx + 1 < len && - !is_whitespace_or_operator(buf[idx + 1])) { - idx++; - }; - break; - } - } - // We pad beyond. - // https://github.com/simdjson/simdjson/issues/906 - // See json_structural_indexer.h for an explanation. 
- *next_structural_index = - len; // assumed later in partial == stage1_mode::streaming_final - next_structural_index[1] = len; - next_structural_index[2] = 0; - parser.n_structural_indexes = - uint32_t(next_structural_index - parser.structural_indexes.get()); - if (simdjson_unlikely(parser.n_structural_indexes == 0)) { - return EMPTY; - } - parser.next_structural_index = 0; - if (partial == stage1_mode::streaming_partial) { - if (unclosed_string) { - parser.n_structural_indexes--; - if (simdjson_unlikely(parser.n_structural_indexes == 0)) { - return CAPACITY; - } - } - // We truncate the input to the end of the last complete document - // (or zero). - auto new_structural_indexes = find_next_document_index(parser); - if (new_structural_indexes == 0 && - parser.n_structural_indexes > 0) { - if (parser.structural_indexes[0] == 0) { - // If the buffer is partial and we started at index 0 but - // the document is - // incomplete, it's too big to parse. - return CAPACITY; - } else { - // It is possible that the document could be parsed, we just - // had a lot - // of white space. - parser.n_structural_indexes = 0; - return EMPTY; - } - } - parser.n_structural_indexes = new_structural_indexes; - } else if (partial == stage1_mode::streaming_final) { - if (unclosed_string) { - parser.n_structural_indexes--; - } - // We truncate the input to the end of the last complete document - // (or zero). - // Because partial == stage1_mode::streaming_final, it means that we - // may - // silently ignore trailing garbage. Though it sounds bad, we do it - // deliberately because many people who have streams of JSON - // documents - // will truncate them for processing. E.g., imagine that you are - // uncompressing - // the data from a size file or receiving it in chunks from the - // network. You - // may not know where exactly the last document will be. 
Meanwhile - // the - // document_stream instances allow people to know the JSON documents - // they are - // parsing (see the iterator.source() method). - parser.n_structural_indexes = find_next_document_index(parser); - // We store the initial n_structural_indexes so that the client can - // see - // whether we used truncation. If initial_n_structural_indexes == - // parser.n_structural_indexes, - // then this will query - // parser.structural_indexes[parser.n_structural_indexes] which is - // len, - // otherwise, it will copy some prior index. - parser.structural_indexes[parser.n_structural_indexes + 1] = - parser.structural_indexes[parser.n_structural_indexes]; - // This next line is critical, do not change it unless you - // understand what you are - // doing. - parser.structural_indexes[parser.n_structural_indexes] = - uint32_t(len); - if (parser.n_structural_indexes == 0) { - return EMPTY; - } - } else if (unclosed_string) { - error = UNCLOSED_STRING; - } - return error; - } - - private: - const uint8_t *buf; - uint32_t *next_structural_index; - dom_parser_implementation &parser; - uint32_t len; - uint32_t idx{0}; - error_code error{SUCCESS}; - stage1_mode partial; -}; // structural_scanner - -} // namespace stage1 -} // unnamed namespace - -simdjson_warn_unused error_code dom_parser_implementation::stage1( - const uint8_t *_buf, size_t _len, stage1_mode partial) noexcept { - this->buf = _buf; - this->len = _len; - stage1::structural_scanner scanner(*this, partial); - return scanner.scan(); -} - -// big table for the minifier -static uint8_t jump_table[256 * 3] = { - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, - 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 
1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, - 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, - 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, - 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, -}; - -simdjson_warn_unused 
error_code implementation::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept { - size_t i = 0, pos = 0; - uint8_t quote = 0; - uint8_t nonescape = 1; - - while (i < len) { - unsigned char c = buf[i]; - uint8_t *meta = jump_table + 3 * c; - - quote = quote ^ (meta[0] & nonescape); - dst[pos] = c; - pos += meta[2] | quote; - - i += 1; - nonescape = uint8_t(~nonescape) | (meta[1]); - } - dst_len = pos; // we intentionally do not work with a reference - // for fear of aliasing - return quote ? UNCLOSED_STRING : SUCCESS; -} - -// credit: based on code from Google Fuchsia (Apache Licensed) -simdjson_warn_unused bool implementation::validate_utf8(const char *buf, - size_t len) const - noexcept { - const uint8_t *data = reinterpret_cast(buf); - uint64_t pos = 0; - uint32_t code_point = 0; - while (pos < len) { - // check of the next 8 bytes are ascii. - uint64_t next_pos = pos + 16; - if (next_pos <= len) { // if it is safe to read 8 more bytes, check - // that they are ascii - uint64_t v1; - memcpy(&v1, data + pos, sizeof(uint64_t)); - uint64_t v2; - memcpy(&v2, data + pos + sizeof(uint64_t), sizeof(uint64_t)); - uint64_t v{v1 | v2}; - if ((v & 0x8080808080808080) == 0) { - pos = next_pos; - continue; - } - } - unsigned char byte = data[pos]; - if (byte < 0b10000000) { - pos++; - continue; - } else if ((byte & 0b11100000) == 0b11000000) { - next_pos = pos + 2; - if (next_pos > len) { - return false; - } - if ((data[pos + 1] & 0b11000000) != 0b10000000) { - return false; - } - // range check - code_point = - (byte & 0b00011111) << 6 | (data[pos + 1] & 0b00111111); - if (code_point < 0x80 || 0x7ff < code_point) { - return false; - } - } else if ((byte & 0b11110000) == 0b11100000) { - next_pos = pos + 3; - if (next_pos > len) { - return false; - } - if ((data[pos + 1] & 0b11000000) != 0b10000000) { - return false; - } - if ((data[pos + 2] & 0b11000000) != 0b10000000) { - return false; - } - // range check - code_point = (byte & 
0b00001111) << 12 | - (data[pos + 1] & 0b00111111) << 6 | - (data[pos + 2] & 0b00111111); - if (code_point < 0x800 || 0xffff < code_point || - (0xd7ff < code_point && code_point < 0xe000)) { - return false; - } - } else if ((byte & 0b11111000) == 0b11110000) { // 0b11110000 - next_pos = pos + 4; - if (next_pos > len) { - return false; - } - if ((data[pos + 1] & 0b11000000) != 0b10000000) { - return false; - } - if ((data[pos + 2] & 0b11000000) != 0b10000000) { - return false; - } - if ((data[pos + 3] & 0b11000000) != 0b10000000) { - return false; - } - // range check - code_point = (byte & 0b00000111) << 18 | - (data[pos + 1] & 0b00111111) << 12 | - (data[pos + 2] & 0b00111111) << 6 | - (data[pos + 3] & 0b00111111); - if (code_point <= 0xffff || 0x10ffff < code_point) { - return false; - } - } else { - // we may have a continuation - return false; - } - pos = next_pos; - } - return true; -} - -} // namespace fallback -} // namespace simdjson - -// -// Stage 2 -// -/* begin file src/generic/stage2/tape_builder.h */ -/* begin file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/logger.h */ -// This is for an internal-only stage 2 specific logger. -// Set LOG_ENABLED = true to log what stage 2 is doing! 
-namespace simdjson { -namespace fallback { -namespace { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static constexpr const int LOG_INDEX_LEN = 5; - -static int log_depth; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static simdjson_really_inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -// Print the header and set up log_start -static simdjson_really_inline void log_start() { - if (LOG_ENABLED) { - log_depth = 0; - printf("\n"); - printf("| %-*s | %-*s | %-*s | %-*s | Detail |\n", - LOG_EVENT_LEN, - "Event", - LOG_BUFFER_LEN, - "Buffer", - LOG_SMALL_BUFFER_LEN, - "Next", - 5, - "Next#"); - printf("|%.*s|%.*s|%.*s|%.*s|--------|\n", - LOG_EVENT_LEN + 2, - DASHES, - LOG_BUFFER_LEN + 2, - DASHES, - LOG_SMALL_BUFFER_LEN + 2, - DASHES, - 5 + 2, - DASHES); - } -} - -simdjson_unused static simdjson_really_inline void log_string( - const char *message) { - if (LOG_ENABLED) { - printf("%s\n", message); - } -} - -// Logs a single line from the stage 2 DOM parser -template -static simdjson_really_inline void log_line(S &structurals, - const char *title_prefix, - const char *title, - const char *detail) { - if (LOG_ENABLED) { - printf("| %*s%s%-*s ", - log_depth * 2, - "", - title_prefix, - LOG_EVENT_LEN - log_depth * 2 - int(strlen(title_prefix)), - title); - auto current_index = 
structurals.at_beginning() - ? nullptr - : structurals.next_structural - 1; - auto next_index = structurals.next_structural; - auto current = current_index ? &structurals.buf[*current_index] - : reinterpret_cast( - " " - " "); - auto next = &structurals.buf[*next_index]; - { - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current[i])); - } - printf(" "); - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next[i])); - } - printf(" "); - } - if (current_index) { - printf("| %*u ", LOG_INDEX_LEN, *current_index); - } else { - printf("| %-*s ", LOG_INDEX_LEN, ""); - } - // printf("| %*u ", LOG_INDEX_LEN, structurals.next_tape_index()); - printf("| %-s ", detail); - printf("|\n"); - } -} - -} // namespace logger -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file src/generic/stage2/logger.h */ - -namespace simdjson { -namespace fallback { -namespace { -namespace stage2 { - -class json_iterator { - public: - const uint8_t *const buf; - uint32_t *next_structural; - dom_parser_implementation &dom_parser; - uint32_t depth{0}; - - /** - * Walk the JSON document. - * - * The visitor receives callbacks when values are encountered. All callbacks - * pass the iterator as - * the first parameter; some callbacks have other parameters as well: - * - * - visit_document_start() - at the beginning. - * - visit_document_end() - at the end (if things were successful). - * - * - visit_array_start() - at the start `[` of a non-empty array. - * - visit_array_end() - at the end `]` of a non-empty array. 
- * - visit_empty_array() - when an empty array is encountered. - * - * - visit_object_end() - at the start `]` of a non-empty object. - * - visit_object_start() - at the end `]` of a non-empty object. - * - visit_empty_object() - when an empty object is encountered. - * - visit_key(const uint8_t *key) - when a key in an object field is - * encountered. key is - * guaranteed to point at the first quote - * of the string (`"key"`). - * - visit_primitive(const uint8_t *value) - when a value is a string, - * number, boolean or null. - * - visit_root_primitive(iter, uint8_t *value) - when the top-level value - * is a string, number, boolean or null. - * - * - increment_count(iter) - each time a value is found in an array or - * object. - */ - template - simdjson_warn_unused simdjson_really_inline error_code - walk_document(V &visitor) noexcept; - - /** - * Create an iterator capable of walking a JSON document. - * - * The document must have already passed through stage 1. - */ - simdjson_really_inline json_iterator(dom_parser_implementation &_dom_parser, - size_t start_structural_index); - - /** - * Look at the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *peek() const noexcept; - /** - * Advance to the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *advance() noexcept; - /** - * Get the remaining length of the document, from the start of the current - * token. - */ - simdjson_really_inline size_t remaining_len() const noexcept; - /** - * Check if we are at the end of the document. - * - * If this is true, there are no more tokens. 
- */ - simdjson_really_inline bool at_eof() const noexcept; - /** - * Check if we are at the beginning of the document. - */ - simdjson_really_inline bool at_beginning() const noexcept; - simdjson_really_inline uint8_t last_structural() const noexcept; - - /** - * Log that a value has been found. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_value(const char *type) const noexcept; - /** - * Log the start of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_start_value(const char *type) const - noexcept; - /** - * Log the end of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_end_value(const char *type) const noexcept; - /** - * Log an error. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_error(const char *error) const noexcept; - - template - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(V &visitor, const uint8_t *value) noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(V &visitor, const uint8_t *value) noexcept; -}; - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::walk_document(V &visitor) noexcept { - logger::log_start(); - - // - // Start the document - // - if (at_eof()) { - return EMPTY; - } - log_start_value("document"); - SIMDJSON_TRY(visitor.visit_document_start(*this)); - - // - // Read first value - // - { - auto value = advance(); - - // Make sure the outer object or array is closed before continuing; - // otherwise, there are ways we - // could get into memory corruption. 
See - // https://github.com/simdjson/simdjson/issues/906 - if (!STREAMING) { - switch (*value) { - case '{': - if (last_structural() != '}') { - log_value("starting brace unmatched"); - return TAPE_ERROR; - }; - break; - case '[': - if (last_structural() != ']') { - log_value("starting bracket unmatched"); - return TAPE_ERROR; - }; - break; - } - } - - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_root_primitive(*this, value)); - break; - } - } - goto document_end; - -// -// Object parser states -// -object_begin: - log_start_value("object"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = false; - SIMDJSON_TRY(visitor.visit_object_start(*this)); - - { - auto key = advance(); - if (*key != '"') { - log_error("Object does not start with a key"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.increment_count(*this)); - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - -object_field: - if (simdjson_unlikely(*advance() != ':')) { - log_error("Missing colon after key in object"); - return TAPE_ERROR; - } - { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } - } - -object_continue: - switch (*advance()) { - case ',': - 
SIMDJSON_TRY(visitor.increment_count(*this)); - { - auto key = advance(); - if (simdjson_unlikely(*key != '"')) { - log_error( - "Key string missing at beginning of field in object"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - goto object_field; - case '}': - log_end_value("object"); - SIMDJSON_TRY(visitor.visit_object_end(*this)); - goto scope_end; - default: - log_error("No comma between object fields"); - return TAPE_ERROR; - } - -scope_end: - depth--; - if (depth == 0) { - goto document_end; - } - if (dom_parser.is_array[depth]) { - goto array_continue; - } - goto object_continue; - -// -// Array parser states -// -array_begin: - log_start_value("array"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = true; - SIMDJSON_TRY(visitor.visit_array_start(*this)); - SIMDJSON_TRY(visitor.increment_count(*this)); - -array_value : { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } -} - -array_continue: - switch (*advance()) { - case ',': - SIMDJSON_TRY(visitor.increment_count(*this)); - goto array_value; - case ']': - log_end_value("array"); - SIMDJSON_TRY(visitor.visit_array_end(*this)); - goto scope_end; - default: - log_error("Missing comma between array values"); - return TAPE_ERROR; - } - -document_end: - log_end_value("document"); - SIMDJSON_TRY(visitor.visit_document_end(*this)); - - dom_parser.next_structural_index = - uint32_t(next_structural - &dom_parser.structural_indexes[0]); - - // If we didn't make it to the end, it's an error - if 
(!STREAMING && - dom_parser.next_structural_index != dom_parser.n_structural_indexes) { - log_error( - "More than one JSON value at the root of the document, or extra " - "characters at the end of the JSON!"); - return TAPE_ERROR; - } - - return SUCCESS; - -} // walk_document() - -simdjson_really_inline json_iterator::json_iterator( - dom_parser_implementation &_dom_parser, size_t start_structural_index) - : buf{_dom_parser.buf}, - next_structural{&_dom_parser.structural_indexes[start_structural_index]}, - dom_parser{_dom_parser} {} - -simdjson_really_inline const uint8_t *json_iterator::peek() const noexcept { - return &buf[*(next_structural)]; -} -simdjson_really_inline const uint8_t *json_iterator::advance() noexcept { - return &buf[*(next_structural++)]; -} -simdjson_really_inline size_t json_iterator::remaining_len() const noexcept { - return dom_parser.len - *(next_structural - 1); -} - -simdjson_really_inline bool json_iterator::at_eof() const noexcept { - return next_structural == - &dom_parser.structural_indexes[dom_parser.n_structural_indexes]; -} -simdjson_really_inline bool json_iterator::at_beginning() const noexcept { - return next_structural == dom_parser.structural_indexes.get(); -} -simdjson_really_inline uint8_t json_iterator::last_structural() const noexcept { - return buf[dom_parser - .structural_indexes[dom_parser.n_structural_indexes - 1]]; -} - -simdjson_really_inline void json_iterator::log_value(const char *type) const - noexcept { - logger::log_line(*this, "", type, ""); -} - -simdjson_really_inline void json_iterator::log_start_value( - const char *type) const noexcept { - logger::log_line(*this, "+", type, ""); - if (logger::LOG_ENABLED) { - logger::log_depth++; - } -} - -simdjson_really_inline void json_iterator::log_end_value(const char *type) const - noexcept { - if (logger::LOG_ENABLED) { - logger::log_depth--; - } - logger::log_line(*this, "-", type, ""); -} - -simdjson_really_inline void json_iterator::log_error(const char *error) 
const - noexcept { - logger::log_line(*this, "", "ERROR", error); -} - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_root_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_root_string(*this, value); - case 't': - return visitor.visit_root_true_atom(*this, value); - case 'f': - return visitor.visit_root_false_atom(*this, value); - case 'n': - return visitor.visit_root_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_root_number(*this, value); - default: - log_error("Document starts with a non-value character"); - return TAPE_ERROR; - } -} -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_string(*this, value); - case 't': - return visitor.visit_true_atom(*this, value); - case 'f': - return visitor.visit_false_atom(*this, value); - case 'n': - return visitor.visit_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_number(*this, value); - default: - log_error("Non-value found when value was expected!"); - return TAPE_ERROR; - } -} - -} // namespace stage2 -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/tape_writer.h */ -namespace simdjson { -namespace fallback { -namespace { -namespace stage2 { - -struct tape_writer { - /** The next place to write to tape */ - uint64_t *next_tape_loc; - - /** Write a signed 64-bit value to tape. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - - /** Write an unsigned 64-bit value to tape. 
*/ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - - /** Write a double value to tape. */ - simdjson_really_inline void append_double(double value) noexcept; - - /** - * Append a tape entry (an 8-bit type,and 56 bits worth of value). - */ - simdjson_really_inline void append(uint64_t val, - internal::tape_type t) noexcept; - - /** - * Skip the current tape entry without writing. - * - * Used to skip the start of the container, since we'll come back later to - * fill it in when the - * container ends. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Skip the number of tape entries necessary to write a large u64 or i64. - */ - simdjson_really_inline void skip_large_integer() noexcept; - - /** - * Skip the number of tape entries necessary to write a double. - */ - simdjson_really_inline void skip_double() noexcept; - - /** - * Write a value to a known location on tape. - * - * Used to go back and write out the start of a container after the - * container ends. - */ - simdjson_really_inline static void write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept; - - private: - /** - * Append both the tape entry, and a supplementary value following it. Used - * for types that need - * all 64 bits, such as double and uint64_t. - */ - template - simdjson_really_inline void append2(uint64_t val, - T val2, - internal::tape_type t) noexcept; -}; // struct number_writer - -simdjson_really_inline void tape_writer::append_s64(int64_t value) noexcept { - append2(0, value, internal::tape_type::INT64); -} - -simdjson_really_inline void tape_writer::append_u64(uint64_t value) noexcept { - append(0, internal::tape_type::UINT64); - *next_tape_loc = value; - next_tape_loc++; -} - -/** Write a double value to tape. 
*/ -simdjson_really_inline void tape_writer::append_double(double value) noexcept { - append2(0, value, internal::tape_type::DOUBLE); -} - -simdjson_really_inline void tape_writer::skip() noexcept { next_tape_loc++; } - -simdjson_really_inline void tape_writer::skip_large_integer() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::skip_double() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::append( - uint64_t val, internal::tape_type t) noexcept { - *next_tape_loc = val | ((uint64_t(char(t))) << 56); - next_tape_loc++; -} - -template -simdjson_really_inline void tape_writer::append2( - uint64_t val, T val2, internal::tape_type t) noexcept { - append(val, t); - static_assert(sizeof(val2) == sizeof(*next_tape_loc), - "Type is not 64 bits!"); - memcpy(next_tape_loc, &val2, sizeof(val2)); - next_tape_loc++; -} - -simdjson_really_inline void tape_writer::write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept { - tape_loc = val | ((uint64_t(char(t))) << 56); -} - -} // namespace stage2 -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file src/generic/stage2/tape_writer.h */ - -namespace simdjson { -namespace fallback { -namespace { -namespace stage2 { - -struct tape_builder { - template - simdjson_warn_unused static simdjson_really_inline error_code - parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept; - - /** Called when a non-empty document starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_start(json_iterator &iter) noexcept; - /** Called when a non-empty document ends without error. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_end(json_iterator &iter) noexcept; - - /** Called when a non-empty array starts. 
*/ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_start(json_iterator &iter) noexcept; - /** Called when a non-empty array ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_end(json_iterator &iter) noexcept; - /** Called when an empty array is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_array(json_iterator &iter) noexcept; - - /** Called when a non-empty object starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_start(json_iterator &iter) noexcept; - /** - * Called when a key in a field is encountered. - * - * primitive, visit_object_start, visit_empty_object, visit_array_start, or - * visit_empty_array - * will be called after this with the field value. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_key(json_iterator &iter, const uint8_t *key) noexcept; - /** Called when a non-empty object ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_end(json_iterator &iter) noexcept; - /** Called when an empty object is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_object(json_iterator &iter) noexcept; - - /** - * Called when a string, number, boolean or null is found. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(json_iterator &iter, const uint8_t *value) noexcept; - /** - * Called when a string, number, boolean or null is found at the top level - * of a document (i.e. - * when there is no array or object and the entire document is a single - * string, number, boolean or - * null. - * - * This is separate from primitive() because simdjson's normal primitive - * parsing routines assume - * there is at least one more token after the value, which is only true in - * an array or object. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code visit_string( - json_iterator &iter, const uint8_t *value, bool key = false) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - visit_root_string(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - /** Called each time a new field or element in an array or object is found. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - increment_count(json_iterator &iter) noexcept; - - /** Next location to write to tape */ - tape_writer tape; - - private: - /** Next write location in the string buf for stage 2 parsing */ - uint8_t *current_string_buf_loc; - - simdjson_really_inline tape_builder(dom::document &doc) noexcept; - - simdjson_really_inline uint32_t next_tape_index(json_iterator &iter) const - noexcept; - simdjson_really_inline void start_container(json_iterator &iter) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_really_inline uint8_t *on_start_string( - json_iterator &iter) noexcept; - simdjson_really_inline void on_end_string(uint8_t *dst) noexcept; -}; // class tape_builder - -template -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept { - dom_parser.doc = &doc; - json_iterator iter(dom_parser, - STREAMING ? 
dom_parser.next_structural_index : 0); - tape_builder builder(doc); - return iter.walk_document(builder); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_root_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_object(json_iterator &iter) noexcept { - return empty_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_array(json_iterator &iter) noexcept { - return empty_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_document_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_end(json_iterator &iter) noexcept { - return end_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_end(json_iterator &iter) noexcept { - return end_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} -simdjson_warn_unused simdjson_really_inline error_code 
-tape_builder::visit_document_end(json_iterator &iter) noexcept { - constexpr uint32_t start_tape_index = 0; - tape.append(start_tape_index, internal::tape_type::ROOT); - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter), - internal::tape_type::ROOT); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_key(json_iterator &iter, const uint8_t *key) noexcept { - return visit_string(iter, key, true); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::increment_count(json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth] - .count++; // we have a key value pair in the object at - // parser.dom_parser.depth - 1 - return SUCCESS; -} - -simdjson_really_inline tape_builder::tape_builder(dom::document &doc) noexcept - : tape{doc.tape.get()}, - current_string_buf_loc{doc.string_buf.get()} {} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_string(json_iterator &iter, - const uint8_t *value, - bool key) noexcept { - iter.log_value(key ? "key" : "string"); - uint8_t *dst = on_start_string(iter); - dst = stringparsing::parse_string(value + 1, dst); - if (dst == nullptr) { - iter.log_error("Invalid escape in string"); - return STRING_ERROR; - } - on_end_string(dst); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_string(json_iterator &iter, - const uint8_t *value) noexcept { - return visit_string(iter, value); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_number(json_iterator &iter, const uint8_t *value) noexcept { - iter.log_value("number"); - return numberparsing::parse_number(value, tape); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_number(json_iterator &iter, - const uint8_t *value) noexcept { - // - // We need to make a copy to make sure that the string is space terminated. 
- // This is not about padding the input, which should already padded up - // to len + SIMDJSON_PADDING. However, we have no control at this stage - // on how the padding was done. What if the input string was padded with - // nulls? - // It is quite common for an input string to have an extra null character (C - // string). - // We do not want to allow 9\0 (where \0 is the null character) inside a - // JSON - // document, but the string "9\0" by itself is fine. So we make a copy and - // pad the input with spaces when we know that there is just one input - // element. - // This copy is relatively expensive, but it will almost never be called in - // practice unless you are in the strange scenario where you have many JSON - // documents made of single atoms. - // - std::unique_ptr copy( - new (std::nothrow) uint8_t[iter.remaining_len() + SIMDJSON_PADDING]); - if (copy.get() == nullptr) { - return MEMALLOC; - } - std::memcpy(copy.get(), value, iter.remaining_len()); - std::memset(copy.get() + iter.remaining_len(), ' ', SIMDJSON_PADDING); - error_code error = visit_number(iter, copy.get()); - return error; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value)) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value, iter.remaining_len())) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if 
(!atomparsing::is_valid_false_atom(value)) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if (!atomparsing::is_valid_false_atom(value, iter.remaining_len())) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value)) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value, iter.remaining_len())) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -// private: - -simdjson_really_inline uint32_t -tape_builder::next_tape_index(json_iterator &iter) const noexcept { - return uint32_t(tape.next_tape_loc - iter.dom_parser.doc->tape.get()); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - auto start_index = next_tape_index(iter); - tape.append(start_index + 2, start); - tape.append(start_index, end); - return SUCCESS; -} - -simdjson_really_inline void tape_builder::start_container( - json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth].tape_index = - next_tape_index(iter); - iter.dom_parser.open_containers[iter.depth].count = 0; - tape.skip(); // We don't actually *write* the start element until the end. 
-} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - // Write the ending tape element, pointing at the start location - const uint32_t start_tape_index = - iter.dom_parser.open_containers[iter.depth].tape_index; - tape.append(start_tape_index, end); - // Write the start tape element, pointing at the end location (and including - // count) - // count can overflow if it exceeds 24 bits... so we saturate - // the convention being that a cnt of 0xffffff or more is undetermined in - // value (>= 0xffffff). - const uint32_t count = iter.dom_parser.open_containers[iter.depth].count; - const uint32_t cntsat = count > 0xFFFFFF ? 0xFFFFFF : count; - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter) | (uint64_t(cntsat) << 32), - start); - return SUCCESS; -} - -simdjson_really_inline uint8_t *tape_builder::on_start_string( - json_iterator &iter) noexcept { - // we advance the point, accounting for the fact that we have a NULL - // termination - tape.append(current_string_buf_loc - iter.dom_parser.doc->string_buf.get(), - internal::tape_type::STRING); - return current_string_buf_loc + sizeof(uint32_t); -} - -simdjson_really_inline void tape_builder::on_end_string(uint8_t *dst) noexcept { - uint32_t str_length = - uint32_t(dst - (current_string_buf_loc + sizeof(uint32_t))); - // TODO check for overflow in case someone has a crazy string (>=4GB?) - // But only add the overflow check when the document itself exceeds 4GB - // Currently unneeded because we refuse to parse docs larger or equal to - // 4GB. - memcpy(current_string_buf_loc, &str_length, sizeof(uint32_t)); - // NULL termination is still handy if you expect all your strings to - // be NULL terminated? 
It comes at a small cost - *dst = 0; - current_string_buf_loc = dst + 1; -} - -} // namespace stage2 -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file src/generic/stage2/tape_builder.h */ - -namespace simdjson { -namespace fallback { - -simdjson_warn_unused error_code -dom_parser_implementation::stage2(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2_next(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code dom_parser_implementation::parse( - const uint8_t *_buf, size_t _len, dom::document &_doc) noexcept { - auto error = stage1(_buf, _len, stage1_mode::regular); - if (error) { - return error; - } - return stage2(_doc); -} - -} // namespace fallback -} // namespace simdjson - -/* begin file include/simdjson/fallback/end.h */ -/* end file include/simdjson/fallback/end.h */ -/* end file src/fallback/dom_parser_implementation.cpp */ -#endif -#if SIMDJSON_IMPLEMENTATION_HASWELL -/* begin file src/haswell/implementation.cpp */ -/* begin file include/simdjson/haswell/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "haswell" -// #define SIMDJSON_IMPLEMENTATION haswell -SIMDJSON_TARGET_HASWELL -/* end file include/simdjson/haswell/begin.h */ - -namespace simdjson { -namespace haswell { - -simdjson_warn_unused error_code -implementation::create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const noexcept { - dst.reset(new (std::nothrow) dom_parser_implementation()); - if (!dst) { - return MEMALLOC; - } - if (auto err = dst->set_capacity(capacity)) return err; - if (auto err = dst->set_max_depth(max_depth)) return err; - return SUCCESS; -} - -} // namespace haswell -} // namespace simdjson - -/* begin file include/simdjson/haswell/end.h */ -SIMDJSON_UNTARGET_HASWELL -/* end file 
include/simdjson/haswell/end.h */ - -/* end file src/haswell/implementation.cpp */ -/* begin file src/haswell/dom_parser_implementation.cpp */ -/* begin file include/simdjson/haswell/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "haswell" -// #define SIMDJSON_IMPLEMENTATION haswell -SIMDJSON_TARGET_HASWELL -/* end file include/simdjson/haswell/begin.h */ - -// -// Stage 1 -// - -namespace simdjson { -namespace haswell { -namespace { - -using namespace simd; - -struct json_character_block { - static simdjson_really_inline json_character_block - classify(const simd::simd8x64 &in); - // ASCII white-space ('\r','\n','\t',' ') - simdjson_really_inline uint64_t whitespace() const noexcept; - // non-quote structural characters (comma, colon, braces, brackets) - simdjson_really_inline uint64_t op() const noexcept; - // neither a structural character nor a white-space, so letters, numbers and - // quotes - simdjson_really_inline uint64_t scalar() const noexcept; - - uint64_t _whitespace; // ASCII white-space ('\r','\n','\t',' ') - uint64_t _op; // structural characters (comma, colon, braces, brackets but - // not quotes) -}; - -simdjson_really_inline uint64_t json_character_block::whitespace() const - noexcept { - return _whitespace; -} -simdjson_really_inline uint64_t json_character_block::op() const noexcept { - return _op; -} -simdjson_really_inline uint64_t json_character_block::scalar() const noexcept { - return ~(op() | whitespace()); -} - -// This identifies structural characters (comma, colon, braces, brackets), -// and ASCII white-space ('\r','\n','\t',' '). -simdjson_really_inline json_character_block -json_character_block::classify(const simd::simd8x64 &in) { - // These lookups rely on the fact that anything < 127 will match the lower 4 - // bits, which is why - // we can't use the generic lookup_16. 
- const auto whitespace_table = simd8::repeat_16(' ', - 100, - 100, - 100, - 17, - 100, - 113, - 2, - 100, - '\t', - '\n', - 112, - 100, - '\r', - 100, - 100); - - // The 6 operators (:,[]{}) have these values: - // - // , 2C - // : 3A - // [ 5B - // { 7B - // ] 5D - // } 7D - // - // If you use | 0x20 to turn [ and ] into { and }, the lower 4 bits of each - // character is unique. - // We exploit this, using a simd 4-bit lookup to tell us which character - // match against, and then - // match it (against | 0x20). - // - // To prevent recognizing other characters, everything else gets compared - // with 0, which cannot - // match due to the | 0x20. - // - // NOTE: Due to the | 0x20, this ALSO treats and (control - // characters 0C and 1A) like , - // and :. This gets caught in stage 2, which checks the actual character to - // ensure the right - // operators are in the right places. - const auto op_table = - simd8::repeat_16(0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ':', - '{', // : = 3A, [ = 5B, { = 7B - ',', - '}', - 0, - 0 // , = 2C, ] = 5D, } = 7D - ); - - // We compute whitespace and op separately. If later code only uses one or - // the - // other, given the fact that all functions are aggressively inlined, we can - // hope that useless computations will be omitted. This is namely case when - // minifying (we only need whitespace). 
- - const uint64_t whitespace = - in.eq({_mm256_shuffle_epi8(whitespace_table, in.chunks[0]), - _mm256_shuffle_epi8(whitespace_table, in.chunks[1])}); - // Turn [ and ] into { and } - const simd8x64 curlified{in.chunks[0] | 0x20, in.chunks[1] | 0x20}; - const uint64_t op = - curlified.eq({_mm256_shuffle_epi8(op_table, in.chunks[0]), - _mm256_shuffle_epi8(op_table, in.chunks[1])}); - - return {whitespace, op}; -} - -simdjson_really_inline bool is_ascii(const simd8x64 &input) { - return input.reduce_or().is_ascii(); -} - -simdjson_unused simdjson_really_inline simd8 must_be_continuation( - const simd8 prev1, - const simd8 prev2, - const simd8 prev3) { - simd8 is_second_byte = - prev1.saturating_sub(0b11000000u - 1); // Only 11______ will be > 0 - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. - return simd8(is_second_byte | is_third_byte | is_fourth_byte) > - int8_t(0); -} - -simdjson_really_inline simd8 must_be_2_3_continuation( - const simd8 prev2, const simd8 prev3) { - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. 
- return simd8(is_third_byte | is_fourth_byte) > int8_t(0); -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -/* begin file src/generic/stage1/utf8_lookup4_algorithm.h */ -namespace simdjson { -namespace haswell { -namespace { -namespace utf8_validation { - -using namespace simd; - -simdjson_really_inline simd8 check_special_cases( - const simd8 input, const simd8 prev1) { - // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII) - // Bit 1 = Too Long (ASCII followed by continuation) - // Bit 2 = Overlong 3-byte - // Bit 4 = Surrogate - // Bit 5 = Overlong 2-byte - // Bit 7 = Two Continuations - constexpr const uint8_t TOO_SHORT = 1 << 0; // 11______ 0_______ - // 11______ 11______ - constexpr const uint8_t TOO_LONG = 1 << 1; // 0_______ 10______ - constexpr const uint8_t OVERLONG_3 = 1 << 2; // 11100000 100_____ - constexpr const uint8_t SURROGATE = 1 << 4; // 11101101 101_____ - constexpr const uint8_t OVERLONG_2 = 1 << 5; // 1100000_ 10______ - constexpr const uint8_t TWO_CONTS = 1 << 7; // 10______ 10______ - constexpr const uint8_t TOO_LARGE = 1 << 3; // 11110100 1001____ - // 11110100 101_____ - // 11110101 1001____ - // 11110101 101_____ - // 1111011_ 1001____ - // 1111011_ 101_____ - // 11111___ 1001____ - // 11111___ 101_____ - constexpr const uint8_t TOO_LARGE_1000 = 1 << 6; - // 11110101 1000____ - // 1111011_ 1000____ - // 11111___ 1000____ - constexpr const uint8_t OVERLONG_4 = 1 << 6; // 11110000 1000____ - - const simd8 byte_1_high = prev1.shr<4>().lookup_16( - // 0_______ ________ - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - // 10______ ________ - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - // 1100____ ________ - TOO_SHORT | OVERLONG_2, - // 1101____ ________ - TOO_SHORT, - // 1110____ ________ - TOO_SHORT | OVERLONG_3 | SURROGATE, - // 1111____ ________ - TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4); - constexpr const uint8_t CARRY = - 
TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 . - const simd8 byte_1_low = - (prev1 & 0x0F) - .lookup_16( - // ____0000 ________ - CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, - // ____0001 ________ - CARRY | OVERLONG_2, - // ____001_ ________ - CARRY, - CARRY, - - // ____0100 ________ - CARRY | TOO_LARGE, - // ____0101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____011_ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - - // ____1___ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____1101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000); - const simd8 byte_2_high = input.shr<4>().lookup_16( - // ________ 0_______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - - // ________ 1000____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | - OVERLONG_4, - // ________ 1001____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, - // ________ 101_____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - - // ________ 11______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT); - return (byte_1_high & byte_1_low & byte_2_high); -} -simdjson_really_inline simd8 check_multibyte_lengths( - const simd8 input, - const simd8 prev_input, - const simd8 sc) { - simd8 prev2 = input.prev<2>(prev_input); - simd8 prev3 = input.prev<3>(prev_input); - simd8 must23 = - simd8(must_be_2_3_continuation(prev2, prev3)); - simd8 must23_80 = must23 & uint8_t(0x80); - return must23_80 ^ sc; -} - -// -// Return nonzero if there are incomplete multibyte characters at the end of the -// block: -// e.g. 
if there is a 4-byte character, but it's 3 bytes from the end. -// -simdjson_really_inline simd8 is_incomplete( - const simd8 input) { - // If the previous input's last 3 bytes match this, they're too short (they - // ended at EOF): - // ... 1111____ 111_____ 11______ - static const uint8_t max_array[32] = {255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 0b11110000u - 1, - 0b11100000u - 1, - 0b11000000u - 1}; - const simd8 max_value( - &max_array[sizeof(max_array) - sizeof(simd8)]); - return input.gt_bits(max_value); -} - -struct utf8_checker { - // If this is nonzero, there has been a UTF-8 error. - simd8 error; - // The last input we received - simd8 prev_input_block; - // Whether the last input we received was incomplete (used for ASCII fast - // path) - simd8 prev_incomplete; - - // - // Check whether the current bytes are valid UTF-8. - // - simdjson_really_inline void check_utf8_bytes( - const simd8 input, const simd8 prev_input) { - // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or - // 4+ lead bytes - // (2, 3, 4-byte leads become large positive numbers instead of small - // negative numbers) - simd8 prev1 = input.prev<1>(prev_input); - simd8 sc = check_special_cases(input, prev1); - this->error |= check_multibyte_lengths(input, prev_input, sc); - } - - // The only problem that can happen at EOF is that a multibyte character is - // too short - // or a byte value too large in the last bytes: check_special_cases only - // checks for bytes - // too large in the first of two bytes. - simdjson_really_inline void check_eof() { - // If the previous block had incomplete UTF-8 characters at the end, an - // ASCII block can't - // possibly finish them. 
- this->error |= this->prev_incomplete; - } - - simdjson_really_inline void check_next_input( - const simd8x64 &input) { - if (simdjson_likely(is_ascii(input))) { - this->error |= this->prev_incomplete; - } else { - // you might think that a for-loop would work, but under Visual - // Studio, it is not good enough. - static_assert( - (simd8x64::NUM_CHUNKS == 2) || - (simd8x64::NUM_CHUNKS == 4), - "We support either two or four chunks per 64-byte block."); - if (simd8x64::NUM_CHUNKS == 2) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - } else if (simd8x64::NUM_CHUNKS == 4) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - this->check_utf8_bytes(input.chunks[2], input.chunks[1]); - this->check_utf8_bytes(input.chunks[3], input.chunks[2]); - } - this->prev_incomplete = - is_incomplete(input.chunks[simd8x64::NUM_CHUNKS - 1]); - this->prev_input_block = - input.chunks[simd8x64::NUM_CHUNKS - 1]; - } - } - // do not forget to call check_eof! - simdjson_really_inline error_code errors() { - return this->error.any_bits_set_anywhere() ? 
error_code::UTF8_ERROR - : error_code::SUCCESS; - } - -}; // struct utf8_checker -} // namespace utf8_validation - -using utf8_validation::utf8_checker; - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/utf8_lookup4_algorithm.h */ -/* begin file src/generic/stage1/json_structural_indexer.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -/* begin file src/generic/stage1/buf_block_reader.h */ -namespace simdjson { -namespace haswell { -namespace { - -// Walks through a buffer in block-sized increments, loading the last part with -// spaces -template -struct buf_block_reader { - public: - simdjson_really_inline buf_block_reader(const uint8_t *_buf, size_t _len); - simdjson_really_inline size_t block_index(); - simdjson_really_inline bool has_full_block() const; - simdjson_really_inline const uint8_t *full_block() const; - /** - * Get the last block, padded with spaces. - * - * There will always be a last block, with at least 1 byte, unless len == 0 - * (in which case this - * function fills the buffer with spaces and returns 0. In particular, if - * len == STEP_SIZE there - * will be 0 full_blocks and 1 remainder block with STEP_SIZE bytes and no - * spaces for padding. - * - * @return the number of effective characters in the last block. 
- */ - simdjson_really_inline size_t get_remainder(uint8_t *dst) const; - simdjson_really_inline void advance(); - - private: - const uint8_t *buf; - const size_t len; - const size_t lenminusstep; - size_t idx; -}; - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text_64(const uint8_t *text) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < sizeof(simd8x64); i++) { - buf[i] = int8_t(text[i]) < ' ' ? '_' : int8_t(text[i]); - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text(const simd8x64 &in) { - static char buf[sizeof(simd8x64) + 1]; - in.store(reinterpret_cast(buf)); - for (size_t i = 0; i < sizeof(simd8x64); i++) { - if (buf[i] < ' ') { - buf[i] = '_'; - } - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -simdjson_unused static char *format_mask(uint64_t mask) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < 64; i++) { - buf[i] = (mask & (size_t(1) << i)) ? 'X' : ' '; - } - buf[64] = '\0'; - return buf; -} - -template -simdjson_really_inline buf_block_reader::buf_block_reader( - const uint8_t *_buf, size_t _len) - : buf{_buf}, - len{_len}, - lenminusstep{len < STEP_SIZE ? 
0 : len - STEP_SIZE}, - idx{0} {} - -template -simdjson_really_inline size_t buf_block_reader::block_index() { - return idx; -} - -template -simdjson_really_inline bool buf_block_reader::has_full_block() - const { - return idx < lenminusstep; -} - -template -simdjson_really_inline const uint8_t *buf_block_reader::full_block() - const { - return &buf[idx]; -} - -template -simdjson_really_inline size_t -buf_block_reader::get_remainder(uint8_t *dst) const { - if (len == idx) { - return 0; - } // memcpy(dst, null, 0) will trigger an error with some sanitizers - std::memset(dst, 0x20, STEP_SIZE); // std::memset STEP_SIZE because it's - // more efficient to write out 8 or 16 - // bytes at once. - std::memcpy(dst, buf + idx, len - idx); - return len - idx; -} - -template -simdjson_really_inline void buf_block_reader::advance() { - idx += STEP_SIZE; -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/buf_block_reader.h */ -/* begin file src/generic/stage1/json_string_scanner.h */ -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -struct json_string_block { - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_string_block(uint64_t backslash, - uint64_t escaped, - uint64_t quote, - uint64_t in_string) - : _backslash(backslash), - _escaped(escaped), - _quote(quote), - _in_string(in_string) {} - - // Escaped characters (characters following an escape() character) - simdjson_really_inline uint64_t escaped() const { return _escaped; } - // Escape characters (backslashes that are not escaped--i.e. 
in \\, includes - // only the first \) - simdjson_really_inline uint64_t escape() const { - return _backslash & ~_escaped; - } - // Real (non-backslashed) quotes - simdjson_really_inline uint64_t quote() const { return _quote; } - // Start quotes of strings - simdjson_really_inline uint64_t string_start() const { - return _quote & _in_string; - } - // End quotes of strings - simdjson_really_inline uint64_t string_end() const { - return _quote & ~_in_string; - } - // Only characters inside the string (not including the quotes) - simdjson_really_inline uint64_t string_content() const { - return _in_string & ~_quote; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_inside_string(uint64_t mask) const { - return mask & _in_string; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const { - return mask & ~_in_string; - } - // Tail of string (everything except the start quote) - simdjson_really_inline uint64_t string_tail() const { - return _in_string ^ _quote; - } - - // backslash characters - uint64_t _backslash; - // escaped characters (backslashed--does not include the hex characters - // after \u) - uint64_t _escaped; - // real quotes (non-backslashed ones) - uint64_t _quote; - // string characters (includes start quote but not end quote) - uint64_t _in_string; -}; - -// Scans blocks for string characters, storing the state necessary to do so -class json_string_scanner { - public: - simdjson_really_inline json_string_block - next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Intended to be defined by the implementation - simdjson_really_inline uint64_t find_escaped(uint64_t escape); - simdjson_really_inline uint64_t 
find_escaped_branchless(uint64_t escape); - - // Whether the last iteration was still inside a string (all 1's = true, all - // 0's = false). - uint64_t prev_in_string = 0ULL; - // Whether the first character of the next iteration is escaped. - uint64_t prev_escaped = 0ULL; -}; - -// -// Finds escaped characters (characters following \). -// -// Handles runs of backslashes like \\\" and \\\\" correctly (yielding 0101 and -// 01010, respectively). -// -// Does this by: -// - Shift the escape mask to get potentially escaped characters (characters -// after backslashes). -// - Mask escaped sequences that start on *even* bits with 1010101010 (odd bits -// are escaped, even bits are not) -// - Mask escaped sequences that start on *odd* bits with 0101010101 (even bits -// are escaped, odd bits are not) -// -// To distinguish between escaped sequences starting on even/odd bits, it finds -// the start of all -// escape sequences, filters out the ones that start on even bits, and adds that -// to the mask of -// escape sequences. This causes the addition to clear out the sequences -// starting on odd bits (since -// the start bit causes a carry), and leaves even-bit sequences alone. 
-// -// Example: -// -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// escape | xxx | xx xxx xxx xx xx | Removed overflow backslash; -// will | it into follows_escape -// odd_starts | x | x x x | escape & ~even_bits & -// ~follows_escape -// even_seq | c| cxxx c xx c | c = carry bit -- will be -// masked out later -// invert_mask | | cxxx c xx c| even_seq << 1 -// follows_escape | xx | x xx xxx xxx xx xx | Includes overflow bit -// escaped | x | x x x x x x x x | -// desired | x | x x x x x x x x | -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// -simdjson_really_inline uint64_t -json_string_scanner::find_escaped_branchless(uint64_t backslash) { - // If there was overflow, pretend the first character isn't a backslash - backslash &= ~prev_escaped; - uint64_t follows_escape = backslash << 1 | prev_escaped; - - // Get sequences starting on even bits by clearing out the odd series using - // + - const uint64_t even_bits = 0x5555555555555555ULL; - uint64_t odd_sequence_starts = backslash & ~even_bits & ~follows_escape; - uint64_t sequences_starting_on_even_bits; - prev_escaped = add_overflow( - odd_sequence_starts, backslash, &sequences_starting_on_even_bits); - uint64_t invert_mask = - sequences_starting_on_even_bits - << 1; // The mask we want to return is the *escaped* bits, not escapes. - - // Mask every other backslashed character as an escaped character - // Flip the mask for sequences that start on even bits, to correct them - return (even_bits ^ invert_mask) & follows_escape; -} - -// -// Return a mask of all string characters plus end quotes. -// -// prev_escaped is overflow saying whether the next character is escaped. -// prev_in_string is overflow saying whether we're still in a string. -// -// Backslash sequences outside of quotes will be detected in stage 2. 
-// -simdjson_really_inline json_string_block -json_string_scanner::next(const simd::simd8x64 &in) { - const uint64_t backslash = in.eq('\\'); - const uint64_t escaped = find_escaped(backslash); - const uint64_t quote = in.eq('"') & ~escaped; - - // - // prefix_xor flips on bits inside the string (and flips off the end quote). - // - // Then we xor with prev_in_string: if we were in a string already, its - // effect is flipped - // (characters inside strings are outside, and characters outside strings - // are inside). - // - const uint64_t in_string = prefix_xor(quote) ^ prev_in_string; - - // - // Check if we're still in a string at the end of the box so the next block - // will know - // - // right shift of a signed value expected to be well-defined and standard - // compliant as of C++20, John Regher from Utah U. says this is fine code - // - prev_in_string = uint64_t(static_cast(in_string) >> 63); - - // Use ^ to turn the beginning quote off, and the end quote on. - - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_string_block(backslash, escaped, quote, in_string); -} - -simdjson_really_inline error_code json_string_scanner::finish() { - if (prev_in_string) { - return UNCLOSED_STRING; - } - return SUCCESS; -} - -} // namespace stage1 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/json_string_scanner.h */ -/* begin file src/generic/stage1/json_scanner.h */ -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -/** - * A block of scanned json, with information on operators and scalars. - * - * We seek to identify pseudo-structural characters. Anything that is inside - * a string must be omitted (hence & ~_string.string_tail()). - * Otherwise, pseudo-structural characters come in two forms. - * 1. We have the structural characters ([,],{,},:, comma). 
The - * term 'structural character' is from the JSON RFC. - * 2. We have the 'scalar pseudo-structural characters'. - * Scalars are quotes, and any character except structural characters and - * white space. - * - * To identify the scalar pseudo-structural characters, we must look at what - * comes - * before them: it must be a space, a quote or a structural characters. - * Starting with simdjson v0.3, we identify them by - * negation: we identify everything that is followed by a non-quote scalar, - * and we negate that. Whatever remains must be a 'scalar pseudo-structural - * character'. - */ -struct json_block { - public: - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_block( - json_string_block &&string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(std::move(string)), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - simdjson_really_inline json_block( - json_string_block string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(string), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - - /** - * The start of structurals. - * In simdjson prior to v0.3, these were called the pseudo-structural - *characters. - **/ - simdjson_really_inline uint64_t structural_start() const noexcept { - return potential_structural_start() & ~_string.string_tail(); - } - /** All JSON whitespace (i.e. 
not in a string) */ - simdjson_really_inline uint64_t whitespace() const noexcept { - return non_quote_outside_string(_characters.whitespace()); - } - - // Helpers - - /** Whether the given characters are inside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t non_quote_inside_string(uint64_t mask) const - noexcept { - return _string.non_quote_inside_string(mask); - } - /** Whether the given characters are outside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const noexcept { - return _string.non_quote_outside_string(mask); - } - - // string and escape characters - json_string_block _string; - // whitespace, structural characters ('operators'), scalars - json_character_block _characters; - // whether the previous character was a scalar - uint64_t _follows_potential_nonquote_scalar; - - private: - // Potential structurals (i.e. disregarding strings) - - /** - * structural elements ([,],{,},:, comma) plus scalar starts like 123, true - *and "abc". - * They may reside inside a string. - **/ - simdjson_really_inline uint64_t potential_structural_start() const - noexcept { - return _characters.op() | potential_scalar_start(); - } - /** - * The start of non-operator runs, like 123, true and "abc". - * It main reside inside a string. - **/ - simdjson_really_inline uint64_t potential_scalar_start() const noexcept { - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // Whenever it is preceded by something that is not a structural element - // ({,},[,],:, ") nor a white-space - // then we know that it is irrelevant structurally. - return _characters.scalar() & ~follows_potential_scalar(); - } - /** - * Whether the given character is immediately after a non-operator like 123, - * true. - * The characters following a quote are not included. 
- */ - simdjson_really_inline uint64_t follows_potential_scalar() const noexcept { - // _follows_potential_nonquote_scalar: is defined as marking any - // character that follows a character - // that is not a structural element ({,},[,],:, comma) nor a quote (") - // and that is not a - // white space. - // It is understood that within quoted region, anything at all could be - // marked (irrelevant). - return _follows_potential_nonquote_scalar; - } -}; - -/** - * Scans JSON for important bits: structural characters or 'operators', strings, - * and scalars. - * - * The scanner starts by calculating two distinct things: - * - string characters (taking \" into account) - * - structural characters or 'operators' ([]{},:, comma) - * and scalars (runs of non-operators like 123, true and "abc") - * - * To minimize data dependency (a key component of the scanner's speed), it - * finds these in parallel: - * in particular, the operator/scalar bit will find plenty of things that are - * actually part of - * strings. When we're done, json_block will fuse the two together by masking - * out tokens that are - * part of a string. - */ -class json_scanner { - public: - json_scanner() {} - simdjson_really_inline json_block next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Whether the last character of the previous iteration is part of a scalar - // token - // (anything except whitespace or a structural character/'operator'). - uint64_t prev_scalar = 0ULL; - json_string_scanner string_scanner{}; -}; - - -// -// Check if the current character immediately follows a matching character. 
-// -// For example, this checks for quotes with backslashes in front of them: -// -// const uint64_t backslashed_quote = in.eq('"') & -// immediately_follows(in.eq('\'), prev_backslash); -// -simdjson_really_inline uint64_t follows(const uint64_t match, - uint64_t &overflow) { - const uint64_t result = match << 1 | overflow; - overflow = match >> 63; - return result; -} - -simdjson_really_inline json_block -json_scanner::next(const simd::simd8x64 &in) { - json_string_block strings = string_scanner.next(in); - // identifies the white-space and the structural characters - json_character_block characters = json_character_block::classify(in); - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // We want follows_scalar to mark anything that follows a non-quote scalar - // (so letters and numbers). - // - // A terminal quote should either be followed by a structural character - // (comma, brace, bracket, colon) - // or nothing. However, we still want ' "a string"true ' to mark the 't' of - // 'true' as a potential - // pseudo-structural character just like we would if we had ' "a string" - // true '; otherwise we - // may need to add an extra check when parsing strings. - // - // Performance: there are many ways to skin this cat. - const uint64_t nonquote_scalar = characters.scalar() & ~strings.quote(); - uint64_t follows_nonquote_scalar = follows(nonquote_scalar, prev_scalar); - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_block(strings, // strings is a function-local object so either - // it moves or the copy is elided. 
- characters, - follows_nonquote_scalar); -} - -simdjson_really_inline error_code json_scanner::finish() { - return string_scanner.finish(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/json_scanner.h */ -/* begin file src/generic/stage1/json_minifier.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -class json_minifier { - public: - template - static error_code minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept; - - private: - simdjson_really_inline json_minifier(uint8_t *_dst) : dst{_dst} {} - template - simdjson_really_inline void step( - const uint8_t *block_buf, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block); - simdjson_really_inline error_code finish(uint8_t *dst_start, - size_t &dst_len); - json_scanner scanner{}; - uint8_t *dst; -}; - -simdjson_really_inline void json_minifier::next( - const simd::simd8x64 &in, const json_block &block) { - uint64_t mask = block.whitespace(); - dst += in.compress(mask, dst); -} - -simdjson_really_inline error_code json_minifier::finish(uint8_t *dst_start, - size_t &dst_len) { - error_code error = scanner.finish(); - if (error) { - dst_len = 0; - return error; - } - dst_len = dst - dst_start; - return SUCCESS; -} - -template <> -simdjson_really_inline void json_minifier::step<128>( - const uint8_t *block_buf, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - simd::simd8x64 in_2(block_buf + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, 
block_1); - this->next(in_2, block_2); - reader.advance(); -} - -template <> -simdjson_really_inline void json_minifier::step<64>( - const uint8_t *block_buf, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - json_block block_1 = scanner.next(in_1); - this->next(block_buf, block_1); - reader.advance(); -} - -template -error_code json_minifier::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept { - buf_block_reader reader(buf, len); - json_minifier minifier(dst); - - // Index the first n-1 blocks - while (reader.has_full_block()) { - minifier.step(reader.full_block(), reader); - } - - // Index the last (remainder) block, padded with spaces - uint8_t block[STEP_SIZE]; - size_t remaining_bytes = reader.get_remainder(block); - if (remaining_bytes > 0) { - // We do not want to write directly to the output stream. Rather, we - // write - // to a local buffer (for safety). - uint8_t out_block[STEP_SIZE]; - uint8_t *const guarded_dst{minifier.dst}; - minifier.dst = out_block; - minifier.step(block, reader); - size_t to_write = minifier.dst - out_block; - // In some cases, we could be enticed to consider the padded spaces - // as part of the string. This is fine as long as we do not write more - // than we consumed. - if (to_write > remaining_bytes) { - to_write = remaining_bytes; - } - memcpy(guarded_dst, out_block, to_write); - minifier.dst = guarded_dst + to_write; - } - return minifier.finish(dst, dst_len); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/json_minifier.h */ -/* begin file src/generic/stage1/find_next_document_index.h */ -namespace simdjson { -namespace haswell { -namespace { - -/** - * This algorithm is used to quickly identify the last structural position that - * makes up a complete document. 
- * - * It does this by going backwards and finding the last *document boundary* (a - * place where one value follows another without a comma between them). If the - * last document (the characters after the boundary) has an equal number of - * start and end brackets, it is considered complete. - * - * Simply put, we iterate over the structural characters, starting from - * the end. We consider that we found the end of a JSON document when the - * first element of the pair is NOT one of these characters: '{' '[' ':' ',' - * and when the second element is NOT one of these characters: '}' ']' ':' ','. - * - * This simple comparison works most of the time, but it does not cover cases - * where the batch's structural indexes contain a perfect amount of documents. - * In such a case, we do not have access to the structural index which follows - * the last document, therefore, we do not have access to the second element in - * the pair, and that means we cannot identify the last document. To fix this - * issue, we keep a count of the open and closed curly/square braces we found - * while searching for the pair. When we find a pair AND the count of open and - * closed curly/square braces is the same, we know that we just passed a - * complete document, therefore the last json buffer location is the end of the - * batch. 
- */ -simdjson_really_inline uint32_t -find_next_document_index(dom_parser_implementation &parser) { - // Variant: do not count separately, just figure out depth - if (parser.n_structural_indexes == 0) { - return 0; - } - auto arr_cnt = 0; - auto obj_cnt = 0; - for (auto i = parser.n_structural_indexes - 1; i > 0; i--) { - auto idxb = parser.structural_indexes[i]; - switch (parser.buf[idxb]) { - case ':': - case ',': - continue; - case '}': - obj_cnt--; - continue; - case ']': - arr_cnt--; - continue; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - auto idxa = parser.structural_indexes[i - 1]; - switch (parser.buf[idxa]) { - case '{': - case '[': - case ':': - case ',': - continue; - } - // Last document is complete, so the next document will appear after! - if (!arr_cnt && !obj_cnt) { - return parser.n_structural_indexes; - } - // Last document is incomplete; mark the document at i + 1 as the next - // one - return i; - } - // If we made it to the end, we want to finish counting to see if we have a - // full document. - switch (parser.buf[parser.structural_indexes[0]]) { - case '}': - obj_cnt--; - break; - case ']': - arr_cnt--; - break; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - if (!arr_cnt && !obj_cnt) { - // We have a complete document. 
- return parser.n_structural_indexes; - } - return 0; -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/find_next_document_index.h */ - -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -class bit_indexer { - public: - uint32_t *tail; - - simdjson_really_inline bit_indexer(uint32_t *index_buf) : tail(index_buf) {} - - // flatten out values in 'bits' assuming that they are are to have values of - // idx - // plus their position in the bitvector, and store these indexes at - // base_ptr[base] incrementing base as we go - // will potentially store extra values beyond end of valid bits, so base_ptr - // needs to be large enough to handle this - simdjson_really_inline void write(uint32_t idx, uint64_t bits) { - // In some instances, the next branch is expensive because it is - // mispredicted. - // Unfortunately, in other cases, - // it helps tremendously. - if (bits == 0) return; -#if defined(SIMDJSON_PREFER_REVERSE_BITS) - /** - * ARM lacks a fast trailing zero instruction, but it has a fast - * bit reversal instruction and a fast leading zero instruction. - * Thus it may be profitable to reverse the bits (once) and then - * to rely on a sequence of instructions that call the leading - * zero instruction. - * - * Performance notes: - * The chosen routine is not optimal in terms of data dependency - * since zero_leading_bit might require two instructions. However, - * it tends to minimize the total number of instructions which is - * beneficial. - */ - - uint64_t rev_bits = reverse_bits(bits); - int cnt = static_cast(count_ones(bits)); - int i = 0; - // Do the first 8 all together - for (; i < 8; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). 
- if (simdjson_unlikely(cnt > 8)) { - i = 8; - for (; i < 16; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. - if (simdjson_unlikely(cnt > 16)) { - i = 16; - while (rev_bits != 0) { - int lz = leading_zeroes(rev_bits); - this->tail[i++] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - } - } - this->tail += cnt; -#else // SIMDJSON_PREFER_REVERSE_BITS - /** - * Under recent x64 systems, we often have both a fast trailing zero - * instruction and a fast 'clear-lower-bit' instruction so the following - * algorithm can be competitive. - */ - - int cnt = static_cast(count_ones(bits)); - // Do the first 8 all together - for (int i = 0; i < 8; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). - if (simdjson_unlikely(cnt > 8)) { - for (int i = 8; i < 16; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. 
- if (simdjson_unlikely(cnt > 16)) { - int i = 16; - do { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - i++; - } while (i < cnt); - } - } - - this->tail += cnt; -#endif - } -}; - -class json_structural_indexer { - public: - /** - * Find the important bits of JSON in a 128-byte chunk, and add them to - * structural_indexes. - * - * @param partial Setting the partial parameter to true allows the - * find_structural_bits to - * tolerate unclosed strings. The caller should still ensure that the - * input is valid UTF-8. If - * you are processing substrings, you may want to call on a function like - * trimmed_length_safe_utf8. - */ - template - static error_code index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept; - - private: - simdjson_really_inline json_structural_indexer( - uint32_t *structural_indexes); - template - simdjson_really_inline void step( - const uint8_t *block, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block, - size_t idx); - simdjson_really_inline error_code finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial); - - json_scanner scanner{}; - utf8_checker checker{}; - bit_indexer indexer; - uint64_t prev_structurals = 0; - uint64_t unescaped_chars_error = 0; -}; - -simdjson_really_inline json_structural_indexer::json_structural_indexer( - uint32_t *structural_indexes) - : indexer{structural_indexes} {} - -// Skip the last character if it is partial -simdjson_really_inline size_t trim_partial_utf8(const uint8_t *buf, - size_t len) { - if (simdjson_unlikely(len < 3)) { - switch (len) { - case 2: - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 2 bytes left - return len; - case 1: - if (buf[len - 
1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - return len; - case 0: - return len; - } - } - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 1 byte left - if (buf[len - 3] >= 0b11110000) { - return len - 3; - } // 4-byte characters with only 3 bytes left - return len; -} - -// -// PERF NOTES: -// We pipe 2 inputs through these stages: -// 1. Load JSON into registers. This takes a long time and is highly -// parallelizable, so we load -// 2 inputs' worth at once so that by the time step 2 is looking for them -// input, it's available. -// 2. Scan the JSON for critical data: strings, scalars and operators. This is -// the critical path. -// The output of step 1 depends entirely on this information. These functions -// don't quite use -// up enough CPU: the second half of the functions is highly serial, only -// using 1 execution core -// at a time. The second input's scans has some dependency on the first ones -// finishing it, but -// they can make a lot of progress before they need that information. -// 3. Step 1 doesn't use enough capacity, so we run some extra stuff while we're -// waiting for that -// to finish: utf-8 checks and generating the output from the last iteration. -// -// The reason we run 2 inputs at a time, is steps 2 and 3 are *still* not enough -// to soak up all -// available capacity with just one input. Running 2 at a time seems to give the -// CPU a good enough -// workout. -// -template -error_code json_structural_indexer::index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept { - if (simdjson_unlikely(len > parser.capacity())) { - return CAPACITY; - } - // We guard the rest of the code so that we can assume that len > 0 - // throughout. 
- if (len == 0) { - return EMPTY; - } - if (is_streaming(partial)) { - len = trim_partial_utf8(buf, len); - // If you end up with an empty window after trimming - // the partial UTF-8 bytes, then chances are good that you - // have an UTF-8 formatting error. - if (len == 0) { - return UTF8_ERROR; - } - } - buf_block_reader reader(buf, len); - json_structural_indexer indexer(parser.structural_indexes.get()); - - // Read all but the last block - while (reader.has_full_block()) { - indexer.step(reader.full_block(), reader); - } - // Take care of the last block (will always be there unless file is empty - // which is - // not supposed to happen.) - uint8_t block[STEP_SIZE]; - if (simdjson_unlikely(reader.get_remainder(block) == 0)) { - return UNEXPECTED_ERROR; - } - indexer.step(block, reader); - return indexer.finish(parser, reader.block_index(), len, partial); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<128>( - const uint8_t *block, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block); - simd::simd8x64 in_2(block + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1, reader.block_index()); - this->next(in_2, block_2, reader.block_index() + 64); - reader.advance(); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<64>( - const uint8_t *block, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block); - json_block block_1 = scanner.next(in_1); - this->next(in_1, block_1, reader.block_index()); - reader.advance(); -} - -simdjson_really_inline void json_structural_indexer::next( - const simd::simd8x64 &in, const json_block &block, size_t idx) { - uint64_t unescaped = in.lteq(0x1F); - checker.check_next_input(in); - indexer.write(uint32_t(idx - 64), prev_structurals); // Output *last* - // iteration's - // structurals to the - // parser - prev_structurals = block.structural_start(); - unescaped_chars_error |= 
block.non_quote_inside_string(unescaped); -} - -simdjson_really_inline error_code -json_structural_indexer::finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial) { - // Write out the final iteration's structurals - indexer.write(uint32_t(idx - 64), prev_structurals); - error_code error = scanner.finish(); - // We deliberately break down the next expression so that it is - // human readable. - const bool should_we_exit = - is_streaming(partial) - ? ((error != SUCCESS) && - (error != - UNCLOSED_STRING)) // when partial we tolerate UNCLOSED_STRING - : (error != SUCCESS); // if partial is false, we must have SUCCESS - const bool have_unclosed_string = (error == UNCLOSED_STRING); - if (simdjson_unlikely(should_we_exit)) { - return error; - } - - if (unescaped_chars_error) { - return UNESCAPED_CHARS; - } - parser.n_structural_indexes = - uint32_t(indexer.tail - parser.structural_indexes.get()); - /*** - * The On Demand API requires special padding. - * - * This is related to https://github.com/simdjson/simdjson/issues/906 - * Basically, we want to make sure that if the parsing continues beyond the - *last (valid) - * structural character, it quickly stops. - * Only three structural characters can be repeated without triggering an - *error in JSON: [,] and }. - * We repeat the padding character (at 'len'). We don't know what it is, but - *if the parsing - * continues, then it must be [,] or }. - * Suppose it is ] or }. We backtrack to the first character, what could it - *be that would - * not trigger an error? It could be ] or } but no, because you can't start - *a document that way. - * It can't be a comma, a colon or any simple value. So the only way we - *could continue is - * if the repeated character is [. But if so, the document must start with - *[. But if the document - * starts with [, it should end with ]. If we enforce that rule, then we - *would get - * ][[ which is invalid. 
- * - * This is illustrated with the test array_iterate_unclosed_error() on the - *following input: - * R"({ "a": [,,)" - **/ - parser.structural_indexes[parser.n_structural_indexes] = - uint32_t(len); // used later in partial == stage1_mode::streaming_final - parser.structural_indexes[parser.n_structural_indexes + 1] = uint32_t(len); - parser.structural_indexes[parser.n_structural_indexes + 2] = 0; - parser.next_structural_index = 0; - // a valid JSON file cannot have zero structural indexes - we should have - // found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return EMPTY; - } - if (simdjson_unlikely( - parser.structural_indexes[parser.n_structural_indexes - 1] > len)) { - return UNEXPECTED_ERROR; - } - if (partial == stage1_mode::streaming_partial) { - // If we have an unclosed string, then the last structural - // will be the quote and we want to make sure to omit it. - if (have_unclosed_string) { - parser.n_structural_indexes--; - // a valid JSON file cannot have zero structural indexes - we should - // have found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return CAPACITY; - } - } - // We truncate the input to the end of the last complete document (or - // zero). - auto new_structural_indexes = find_next_document_index(parser); - if (new_structural_indexes == 0 && parser.n_structural_indexes > 0) { - if (parser.structural_indexes[0] == 0) { - // If the buffer is partial and we started at index 0 but the - // document is - // incomplete, it's too big to parse. - return CAPACITY; - } else { - // It is possible that the document could be parsed, we just had - // a lot - // of white space. 
- parser.n_structural_indexes = 0; - return EMPTY; - } - } - - parser.n_structural_indexes = new_structural_indexes; - } else if (partial == stage1_mode::streaming_final) { - if (have_unclosed_string) { - parser.n_structural_indexes--; - } - // We truncate the input to the end of the last complete document (or - // zero). - // Because partial == stage1_mode::streaming_final, it means that we may - // silently ignore trailing garbage. Though it sounds bad, we do it - // deliberately because many people who have streams of JSON documents - // will truncate them for processing. E.g., imagine that you are - // uncompressing - // the data from a size file or receiving it in chunks from the network. - // You - // may not know where exactly the last document will be. Meanwhile the - // document_stream instances allow people to know the JSON documents - // they are - // parsing (see the iterator.source() method). - parser.n_structural_indexes = find_next_document_index(parser); - // We store the initial n_structural_indexes so that the client can see - // whether we used truncation. If initial_n_structural_indexes == - // parser.n_structural_indexes, - // then this will query - // parser.structural_indexes[parser.n_structural_indexes] which is len, - // otherwise, it will copy some prior index. - parser.structural_indexes[parser.n_structural_indexes + 1] = - parser.structural_indexes[parser.n_structural_indexes]; - // This next line is critical, do not change it unless you understand - // what you are - // doing. - parser.structural_indexes[parser.n_structural_indexes] = uint32_t(len); - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - // We tolerate an unclosed string at the very end of the stream. - // Indeed, users - // often load their data in bulk without being careful and they want - // us to ignore - // the trailing garbage. 
- return EMPTY; - } - } - checker.check_eof(); - return checker.errors(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/json_structural_indexer.h */ -/* begin file src/generic/stage1/utf8_validator.h */ -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -/** - * Validates that the string is actual UTF-8. - */ -template -bool generic_validate_utf8(const uint8_t *input, size_t length) { - checker c{}; - buf_block_reader<64> reader(input, length); - while (reader.has_full_block()) { - simd::simd8x64 in(reader.full_block()); - c.check_next_input(in); - reader.advance(); - } - uint8_t block[64]{}; - reader.get_remainder(block); - simd::simd8x64 in(block); - c.check_next_input(in); - reader.advance(); - c.check_eof(); - return c.errors() == error_code::SUCCESS; -} - -bool generic_validate_utf8(const char *input, size_t length) { - return generic_validate_utf8( - reinterpret_cast(input), length); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage1/utf8_validator.h */ - -// -// Stage 2 -// -/* begin file src/generic/stage2/tape_builder.h */ -/* begin file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/logger.h */ -// This is for an internal-only stage 2 specific logger. -// Set LOG_ENABLED = true to log what stage 2 is doing! 
-namespace simdjson { -namespace haswell { -namespace { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static constexpr const int LOG_INDEX_LEN = 5; - -static int log_depth; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static simdjson_really_inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -// Print the header and set up log_start -static simdjson_really_inline void log_start() { - if (LOG_ENABLED) { - log_depth = 0; - printf("\n"); - printf("| %-*s | %-*s | %-*s | %-*s | Detail |\n", - LOG_EVENT_LEN, - "Event", - LOG_BUFFER_LEN, - "Buffer", - LOG_SMALL_BUFFER_LEN, - "Next", - 5, - "Next#"); - printf("|%.*s|%.*s|%.*s|%.*s|--------|\n", - LOG_EVENT_LEN + 2, - DASHES, - LOG_BUFFER_LEN + 2, - DASHES, - LOG_SMALL_BUFFER_LEN + 2, - DASHES, - 5 + 2, - DASHES); - } -} - -simdjson_unused static simdjson_really_inline void log_string( - const char *message) { - if (LOG_ENABLED) { - printf("%s\n", message); - } -} - -// Logs a single line from the stage 2 DOM parser -template -static simdjson_really_inline void log_line(S &structurals, - const char *title_prefix, - const char *title, - const char *detail) { - if (LOG_ENABLED) { - printf("| %*s%s%-*s ", - log_depth * 2, - "", - title_prefix, - LOG_EVENT_LEN - log_depth * 2 - int(strlen(title_prefix)), - title); - auto current_index = 
structurals.at_beginning() - ? nullptr - : structurals.next_structural - 1; - auto next_index = structurals.next_structural; - auto current = current_index ? &structurals.buf[*current_index] - : reinterpret_cast( - " " - " "); - auto next = &structurals.buf[*next_index]; - { - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current[i])); - } - printf(" "); - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next[i])); - } - printf(" "); - } - if (current_index) { - printf("| %*u ", LOG_INDEX_LEN, *current_index); - } else { - printf("| %-*s ", LOG_INDEX_LEN, ""); - } - // printf("| %*u ", LOG_INDEX_LEN, structurals.next_tape_index()); - printf("| %-s ", detail); - printf("|\n"); - } -} - -} // namespace logger -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage2/logger.h */ - -namespace simdjson { -namespace haswell { -namespace { -namespace stage2 { - -class json_iterator { - public: - const uint8_t *const buf; - uint32_t *next_structural; - dom_parser_implementation &dom_parser; - uint32_t depth{0}; - - /** - * Walk the JSON document. - * - * The visitor receives callbacks when values are encountered. All callbacks - * pass the iterator as - * the first parameter; some callbacks have other parameters as well: - * - * - visit_document_start() - at the beginning. - * - visit_document_end() - at the end (if things were successful). - * - * - visit_array_start() - at the start `[` of a non-empty array. - * - visit_array_end() - at the end `]` of a non-empty array. 
- * - visit_empty_array() - when an empty array is encountered. - * - * - visit_object_end() - at the start `]` of a non-empty object. - * - visit_object_start() - at the end `]` of a non-empty object. - * - visit_empty_object() - when an empty object is encountered. - * - visit_key(const uint8_t *key) - when a key in an object field is - * encountered. key is - * guaranteed to point at the first quote - * of the string (`"key"`). - * - visit_primitive(const uint8_t *value) - when a value is a string, - * number, boolean or null. - * - visit_root_primitive(iter, uint8_t *value) - when the top-level value - * is a string, number, boolean or null. - * - * - increment_count(iter) - each time a value is found in an array or - * object. - */ - template - simdjson_warn_unused simdjson_really_inline error_code - walk_document(V &visitor) noexcept; - - /** - * Create an iterator capable of walking a JSON document. - * - * The document must have already passed through stage 1. - */ - simdjson_really_inline json_iterator(dom_parser_implementation &_dom_parser, - size_t start_structural_index); - - /** - * Look at the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *peek() const noexcept; - /** - * Advance to the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *advance() noexcept; - /** - * Get the remaining length of the document, from the start of the current - * token. - */ - simdjson_really_inline size_t remaining_len() const noexcept; - /** - * Check if we are at the end of the document. - * - * If this is true, there are no more tokens. 
- */ - simdjson_really_inline bool at_eof() const noexcept; - /** - * Check if we are at the beginning of the document. - */ - simdjson_really_inline bool at_beginning() const noexcept; - simdjson_really_inline uint8_t last_structural() const noexcept; - - /** - * Log that a value has been found. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_value(const char *type) const noexcept; - /** - * Log the start of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_start_value(const char *type) const - noexcept; - /** - * Log the end of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_end_value(const char *type) const noexcept; - /** - * Log an error. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_error(const char *error) const noexcept; - - template - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(V &visitor, const uint8_t *value) noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(V &visitor, const uint8_t *value) noexcept; -}; - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::walk_document(V &visitor) noexcept { - logger::log_start(); - - // - // Start the document - // - if (at_eof()) { - return EMPTY; - } - log_start_value("document"); - SIMDJSON_TRY(visitor.visit_document_start(*this)); - - // - // Read first value - // - { - auto value = advance(); - - // Make sure the outer object or array is closed before continuing; - // otherwise, there are ways we - // could get into memory corruption. 
See - // https://github.com/simdjson/simdjson/issues/906 - if (!STREAMING) { - switch (*value) { - case '{': - if (last_structural() != '}') { - log_value("starting brace unmatched"); - return TAPE_ERROR; - }; - break; - case '[': - if (last_structural() != ']') { - log_value("starting bracket unmatched"); - return TAPE_ERROR; - }; - break; - } - } - - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_root_primitive(*this, value)); - break; - } - } - goto document_end; - -// -// Object parser states -// -object_begin: - log_start_value("object"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = false; - SIMDJSON_TRY(visitor.visit_object_start(*this)); - - { - auto key = advance(); - if (*key != '"') { - log_error("Object does not start with a key"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.increment_count(*this)); - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - -object_field: - if (simdjson_unlikely(*advance() != ':')) { - log_error("Missing colon after key in object"); - return TAPE_ERROR; - } - { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } - } - -object_continue: - switch (*advance()) { - case ',': - 
SIMDJSON_TRY(visitor.increment_count(*this)); - { - auto key = advance(); - if (simdjson_unlikely(*key != '"')) { - log_error( - "Key string missing at beginning of field in object"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - goto object_field; - case '}': - log_end_value("object"); - SIMDJSON_TRY(visitor.visit_object_end(*this)); - goto scope_end; - default: - log_error("No comma between object fields"); - return TAPE_ERROR; - } - -scope_end: - depth--; - if (depth == 0) { - goto document_end; - } - if (dom_parser.is_array[depth]) { - goto array_continue; - } - goto object_continue; - -// -// Array parser states -// -array_begin: - log_start_value("array"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = true; - SIMDJSON_TRY(visitor.visit_array_start(*this)); - SIMDJSON_TRY(visitor.increment_count(*this)); - -array_value : { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } -} - -array_continue: - switch (*advance()) { - case ',': - SIMDJSON_TRY(visitor.increment_count(*this)); - goto array_value; - case ']': - log_end_value("array"); - SIMDJSON_TRY(visitor.visit_array_end(*this)); - goto scope_end; - default: - log_error("Missing comma between array values"); - return TAPE_ERROR; - } - -document_end: - log_end_value("document"); - SIMDJSON_TRY(visitor.visit_document_end(*this)); - - dom_parser.next_structural_index = - uint32_t(next_structural - &dom_parser.structural_indexes[0]); - - // If we didn't make it to the end, it's an error - if 
(!STREAMING && - dom_parser.next_structural_index != dom_parser.n_structural_indexes) { - log_error( - "More than one JSON value at the root of the document, or extra " - "characters at the end of the JSON!"); - return TAPE_ERROR; - } - - return SUCCESS; - -} // walk_document() - -simdjson_really_inline json_iterator::json_iterator( - dom_parser_implementation &_dom_parser, size_t start_structural_index) - : buf{_dom_parser.buf}, - next_structural{&_dom_parser.structural_indexes[start_structural_index]}, - dom_parser{_dom_parser} {} - -simdjson_really_inline const uint8_t *json_iterator::peek() const noexcept { - return &buf[*(next_structural)]; -} -simdjson_really_inline const uint8_t *json_iterator::advance() noexcept { - return &buf[*(next_structural++)]; -} -simdjson_really_inline size_t json_iterator::remaining_len() const noexcept { - return dom_parser.len - *(next_structural - 1); -} - -simdjson_really_inline bool json_iterator::at_eof() const noexcept { - return next_structural == - &dom_parser.structural_indexes[dom_parser.n_structural_indexes]; -} -simdjson_really_inline bool json_iterator::at_beginning() const noexcept { - return next_structural == dom_parser.structural_indexes.get(); -} -simdjson_really_inline uint8_t json_iterator::last_structural() const noexcept { - return buf[dom_parser - .structural_indexes[dom_parser.n_structural_indexes - 1]]; -} - -simdjson_really_inline void json_iterator::log_value(const char *type) const - noexcept { - logger::log_line(*this, "", type, ""); -} - -simdjson_really_inline void json_iterator::log_start_value( - const char *type) const noexcept { - logger::log_line(*this, "+", type, ""); - if (logger::LOG_ENABLED) { - logger::log_depth++; - } -} - -simdjson_really_inline void json_iterator::log_end_value(const char *type) const - noexcept { - if (logger::LOG_ENABLED) { - logger::log_depth--; - } - logger::log_line(*this, "-", type, ""); -} - -simdjson_really_inline void json_iterator::log_error(const char *error) 
const - noexcept { - logger::log_line(*this, "", "ERROR", error); -} - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_root_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_root_string(*this, value); - case 't': - return visitor.visit_root_true_atom(*this, value); - case 'f': - return visitor.visit_root_false_atom(*this, value); - case 'n': - return visitor.visit_root_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_root_number(*this, value); - default: - log_error("Document starts with a non-value character"); - return TAPE_ERROR; - } -} -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_string(*this, value); - case 't': - return visitor.visit_true_atom(*this, value); - case 'f': - return visitor.visit_false_atom(*this, value); - case 'n': - return visitor.visit_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_number(*this, value); - default: - log_error("Non-value found when value was expected!"); - return TAPE_ERROR; - } -} - -} // namespace stage2 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/tape_writer.h */ -namespace simdjson { -namespace haswell { -namespace { -namespace stage2 { - -struct tape_writer { - /** The next place to write to tape */ - uint64_t *next_tape_loc; - - /** Write a signed 64-bit value to tape. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - - /** Write an unsigned 64-bit value to tape. 
*/ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - - /** Write a double value to tape. */ - simdjson_really_inline void append_double(double value) noexcept; - - /** - * Append a tape entry (an 8-bit type,and 56 bits worth of value). - */ - simdjson_really_inline void append(uint64_t val, - internal::tape_type t) noexcept; - - /** - * Skip the current tape entry without writing. - * - * Used to skip the start of the container, since we'll come back later to - * fill it in when the - * container ends. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Skip the number of tape entries necessary to write a large u64 or i64. - */ - simdjson_really_inline void skip_large_integer() noexcept; - - /** - * Skip the number of tape entries necessary to write a double. - */ - simdjson_really_inline void skip_double() noexcept; - - /** - * Write a value to a known location on tape. - * - * Used to go back and write out the start of a container after the - * container ends. - */ - simdjson_really_inline static void write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept; - - private: - /** - * Append both the tape entry, and a supplementary value following it. Used - * for types that need - * all 64 bits, such as double and uint64_t. - */ - template - simdjson_really_inline void append2(uint64_t val, - T val2, - internal::tape_type t) noexcept; -}; // struct number_writer - -simdjson_really_inline void tape_writer::append_s64(int64_t value) noexcept { - append2(0, value, internal::tape_type::INT64); -} - -simdjson_really_inline void tape_writer::append_u64(uint64_t value) noexcept { - append(0, internal::tape_type::UINT64); - *next_tape_loc = value; - next_tape_loc++; -} - -/** Write a double value to tape. 
*/ -simdjson_really_inline void tape_writer::append_double(double value) noexcept { - append2(0, value, internal::tape_type::DOUBLE); -} - -simdjson_really_inline void tape_writer::skip() noexcept { next_tape_loc++; } - -simdjson_really_inline void tape_writer::skip_large_integer() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::skip_double() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::append( - uint64_t val, internal::tape_type t) noexcept { - *next_tape_loc = val | ((uint64_t(char(t))) << 56); - next_tape_loc++; -} - -template -simdjson_really_inline void tape_writer::append2( - uint64_t val, T val2, internal::tape_type t) noexcept { - append(val, t); - static_assert(sizeof(val2) == sizeof(*next_tape_loc), - "Type is not 64 bits!"); - memcpy(next_tape_loc, &val2, sizeof(val2)); - next_tape_loc++; -} - -simdjson_really_inline void tape_writer::write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept { - tape_loc = val | ((uint64_t(char(t))) << 56); -} - -} // namespace stage2 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage2/tape_writer.h */ - -namespace simdjson { -namespace haswell { -namespace { -namespace stage2 { - -struct tape_builder { - template - simdjson_warn_unused static simdjson_really_inline error_code - parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept; - - /** Called when a non-empty document starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_start(json_iterator &iter) noexcept; - /** Called when a non-empty document ends without error. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_end(json_iterator &iter) noexcept; - - /** Called when a non-empty array starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_start(json_iterator &iter) noexcept; - /** Called when a non-empty array ends. 
*/ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_end(json_iterator &iter) noexcept; - /** Called when an empty array is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_array(json_iterator &iter) noexcept; - - /** Called when a non-empty object starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_start(json_iterator &iter) noexcept; - /** - * Called when a key in a field is encountered. - * - * primitive, visit_object_start, visit_empty_object, visit_array_start, or - * visit_empty_array - * will be called after this with the field value. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_key(json_iterator &iter, const uint8_t *key) noexcept; - /** Called when a non-empty object ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_end(json_iterator &iter) noexcept; - /** Called when an empty object is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_object(json_iterator &iter) noexcept; - - /** - * Called when a string, number, boolean or null is found. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(json_iterator &iter, const uint8_t *value) noexcept; - /** - * Called when a string, number, boolean or null is found at the top level - * of a document (i.e. - * when there is no array or object and the entire document is a single - * string, number, boolean or - * null. - * - * This is separate from primitive() because simdjson's normal primitive - * parsing routines assume - * there is at least one more token after the value, which is only true in - * an array or object. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code visit_string( - json_iterator &iter, const uint8_t *value, bool key = false) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - visit_root_string(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - /** Called each time a new field or element in an array or object is found. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - increment_count(json_iterator &iter) noexcept; - - /** Next location to write to tape */ - tape_writer tape; - - private: - /** Next write location in the string buf for stage 2 parsing */ - uint8_t *current_string_buf_loc; - - simdjson_really_inline tape_builder(dom::document &doc) noexcept; - - simdjson_really_inline uint32_t next_tape_index(json_iterator &iter) const - noexcept; - simdjson_really_inline void start_container(json_iterator &iter) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_really_inline uint8_t *on_start_string( - json_iterator &iter) noexcept; - simdjson_really_inline void on_end_string(uint8_t *dst) noexcept; -}; // class tape_builder - -template -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept { - dom_parser.doc = &doc; - json_iterator iter(dom_parser, - STREAMING ? 
dom_parser.next_structural_index : 0); - tape_builder builder(doc); - return iter.walk_document(builder); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_root_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_object(json_iterator &iter) noexcept { - return empty_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_array(json_iterator &iter) noexcept { - return empty_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_document_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_end(json_iterator &iter) noexcept { - return end_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_end(json_iterator &iter) noexcept { - return end_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} -simdjson_warn_unused simdjson_really_inline error_code 
-tape_builder::visit_document_end(json_iterator &iter) noexcept { - constexpr uint32_t start_tape_index = 0; - tape.append(start_tape_index, internal::tape_type::ROOT); - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter), - internal::tape_type::ROOT); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_key(json_iterator &iter, const uint8_t *key) noexcept { - return visit_string(iter, key, true); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::increment_count(json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth] - .count++; // we have a key value pair in the object at - // parser.dom_parser.depth - 1 - return SUCCESS; -} - -simdjson_really_inline tape_builder::tape_builder(dom::document &doc) noexcept - : tape{doc.tape.get()}, - current_string_buf_loc{doc.string_buf.get()} {} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_string(json_iterator &iter, - const uint8_t *value, - bool key) noexcept { - iter.log_value(key ? "key" : "string"); - uint8_t *dst = on_start_string(iter); - dst = stringparsing::parse_string(value + 1, dst); - if (dst == nullptr) { - iter.log_error("Invalid escape in string"); - return STRING_ERROR; - } - on_end_string(dst); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_string(json_iterator &iter, - const uint8_t *value) noexcept { - return visit_string(iter, value); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_number(json_iterator &iter, const uint8_t *value) noexcept { - iter.log_value("number"); - return numberparsing::parse_number(value, tape); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_number(json_iterator &iter, - const uint8_t *value) noexcept { - // - // We need to make a copy to make sure that the string is space terminated. 
- // This is not about padding the input, which should already padded up - // to len + SIMDJSON_PADDING. However, we have no control at this stage - // on how the padding was done. What if the input string was padded with - // nulls? - // It is quite common for an input string to have an extra null character (C - // string). - // We do not want to allow 9\0 (where \0 is the null character) inside a - // JSON - // document, but the string "9\0" by itself is fine. So we make a copy and - // pad the input with spaces when we know that there is just one input - // element. - // This copy is relatively expensive, but it will almost never be called in - // practice unless you are in the strange scenario where you have many JSON - // documents made of single atoms. - // - std::unique_ptr copy( - new (std::nothrow) uint8_t[iter.remaining_len() + SIMDJSON_PADDING]); - if (copy.get() == nullptr) { - return MEMALLOC; - } - std::memcpy(copy.get(), value, iter.remaining_len()); - std::memset(copy.get() + iter.remaining_len(), ' ', SIMDJSON_PADDING); - error_code error = visit_number(iter, copy.get()); - return error; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value)) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value, iter.remaining_len())) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if 
(!atomparsing::is_valid_false_atom(value)) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if (!atomparsing::is_valid_false_atom(value, iter.remaining_len())) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value)) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value, iter.remaining_len())) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -// private: - -simdjson_really_inline uint32_t -tape_builder::next_tape_index(json_iterator &iter) const noexcept { - return uint32_t(tape.next_tape_loc - iter.dom_parser.doc->tape.get()); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - auto start_index = next_tape_index(iter); - tape.append(start_index + 2, start); - tape.append(start_index, end); - return SUCCESS; -} - -simdjson_really_inline void tape_builder::start_container( - json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth].tape_index = - next_tape_index(iter); - iter.dom_parser.open_containers[iter.depth].count = 0; - tape.skip(); // We don't actually *write* the start element until the end. 
-} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - // Write the ending tape element, pointing at the start location - const uint32_t start_tape_index = - iter.dom_parser.open_containers[iter.depth].tape_index; - tape.append(start_tape_index, end); - // Write the start tape element, pointing at the end location (and including - // count) - // count can overflow if it exceeds 24 bits... so we saturate - // the convention being that a cnt of 0xffffff or more is undetermined in - // value (>= 0xffffff). - const uint32_t count = iter.dom_parser.open_containers[iter.depth].count; - const uint32_t cntsat = count > 0xFFFFFF ? 0xFFFFFF : count; - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter) | (uint64_t(cntsat) << 32), - start); - return SUCCESS; -} - -simdjson_really_inline uint8_t *tape_builder::on_start_string( - json_iterator &iter) noexcept { - // we advance the point, accounting for the fact that we have a NULL - // termination - tape.append(current_string_buf_loc - iter.dom_parser.doc->string_buf.get(), - internal::tape_type::STRING); - return current_string_buf_loc + sizeof(uint32_t); -} - -simdjson_really_inline void tape_builder::on_end_string(uint8_t *dst) noexcept { - uint32_t str_length = - uint32_t(dst - (current_string_buf_loc + sizeof(uint32_t))); - // TODO check for overflow in case someone has a crazy string (>=4GB?) - // But only add the overflow check when the document itself exceeds 4GB - // Currently unneeded because we refuse to parse docs larger or equal to - // 4GB. - memcpy(current_string_buf_loc, &str_length, sizeof(uint32_t)); - // NULL termination is still handy if you expect all your strings to - // be NULL terminated? 
It comes at a small cost - *dst = 0; - current_string_buf_loc = dst + 1; -} - -} // namespace stage2 -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file src/generic/stage2/tape_builder.h */ - -// -// Implementation-specific overrides -// -namespace simdjson { -namespace haswell { -namespace { -namespace stage1 { - -simdjson_really_inline uint64_t -json_string_scanner::find_escaped(uint64_t backslash) { - if (!backslash) { - uint64_t escaped = prev_escaped; - prev_escaped = 0; - return escaped; - } - return find_escaped_branchless(backslash); -} - -} // namespace stage1 -} // unnamed namespace - -simdjson_warn_unused error_code implementation::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept { - return haswell::stage1::json_minifier::minify<128>(buf, len, dst, dst_len); -} - -simdjson_warn_unused error_code dom_parser_implementation::stage1( - const uint8_t *_buf, size_t _len, stage1_mode streaming) noexcept { - this->buf = _buf; - this->len = _len; - return haswell::stage1::json_structural_indexer::index<128>( - _buf, _len, *this, streaming); -} - -simdjson_warn_unused bool implementation::validate_utf8(const char *buf, - size_t len) const - noexcept { - return haswell::stage1::generic_validate_utf8(buf, len); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2_next(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code dom_parser_implementation::parse( - const uint8_t *_buf, size_t _len, dom::document &_doc) noexcept { - auto error = stage1(_buf, _len, stage1_mode::regular); - if (error) { - return error; - } - return stage2(_doc); -} - -} // namespace haswell -} // namespace simdjson - -/* begin file include/simdjson/haswell/end.h 
*/ -SIMDJSON_UNTARGET_HASWELL -/* end file include/simdjson/haswell/end.h */ -/* end file src/haswell/dom_parser_implementation.cpp */ -#endif -#if SIMDJSON_IMPLEMENTATION_PPC64 -/* begin file src/ppc64/implementation.cpp */ -/* begin file include/simdjson/ppc64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "ppc64" -// #define SIMDJSON_IMPLEMENTATION ppc64 -/* end file include/simdjson/ppc64/begin.h */ - -namespace simdjson { -namespace ppc64 { - -simdjson_warn_unused error_code -implementation::create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const noexcept { - dst.reset(new (std::nothrow) dom_parser_implementation()); - if (!dst) { - return MEMALLOC; - } - if (auto err = dst->set_capacity(capacity)) return err; - if (auto err = dst->set_max_depth(max_depth)) return err; - return SUCCESS; -} - -} // namespace ppc64 -} // namespace simdjson - -/* begin file include/simdjson/ppc64/end.h */ -/* end file include/simdjson/ppc64/end.h */ -/* end file src/ppc64/implementation.cpp */ -/* begin file src/ppc64/dom_parser_implementation.cpp */ -/* begin file include/simdjson/ppc64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "ppc64" -// #define SIMDJSON_IMPLEMENTATION ppc64 -/* end file include/simdjson/ppc64/begin.h */ - -// -// Stage 1 -// -namespace simdjson { -namespace ppc64 { -namespace { - -using namespace simd; - -struct json_character_block { - static simdjson_really_inline json_character_block - classify(const simd::simd8x64 &in); - - simdjson_really_inline uint64_t whitespace() const noexcept { - return _whitespace; - } - simdjson_really_inline uint64_t op() const noexcept { return _op; } - simdjson_really_inline uint64_t scalar() const noexcept { - return ~(op() | whitespace()); - } - - uint64_t _whitespace; - uint64_t _op; -}; - -simdjson_really_inline json_character_block -json_character_block::classify(const simd::simd8x64 &in) { - const simd8 table1( - 16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2, 9, 
0, 0); - const simd8 table2( - 8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1, 0, 0); - - simd8x64 v((in.chunks[0] & 0xf).lookup_16(table1) & - (in.chunks[0].shr<4>()).lookup_16(table2), - (in.chunks[1] & 0xf).lookup_16(table1) & - (in.chunks[1].shr<4>()).lookup_16(table2), - (in.chunks[2] & 0xf).lookup_16(table1) & - (in.chunks[2].shr<4>()).lookup_16(table2), - (in.chunks[3] & 0xf).lookup_16(table1) & - (in.chunks[3].shr<4>()).lookup_16(table2)); - - uint64_t op = simd8x64(v.chunks[0].any_bits_set(0x7), - v.chunks[1].any_bits_set(0x7), - v.chunks[2].any_bits_set(0x7), - v.chunks[3].any_bits_set(0x7)) - .to_bitmask(); - - uint64_t whitespace = simd8x64(v.chunks[0].any_bits_set(0x18), - v.chunks[1].any_bits_set(0x18), - v.chunks[2].any_bits_set(0x18), - v.chunks[3].any_bits_set(0x18)) - .to_bitmask(); - - return {whitespace, op}; -} - -simdjson_really_inline bool is_ascii(const simd8x64 &input) { - // careful: 0x80 is not ascii. - return input.reduce_or() - .saturating_sub(0b01111111u) - .bits_not_set_anywhere(); -} - -simdjson_unused simdjson_really_inline simd8 must_be_continuation( - const simd8 prev1, - const simd8 prev2, - const simd8 prev3) { - simd8 is_second_byte = - prev1.saturating_sub(0b11000000u - 1); // Only 11______ will be > 0 - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. - return simd8(is_second_byte | is_third_byte | is_fourth_byte) > - int8_t(0); -} - -simdjson_really_inline simd8 must_be_2_3_continuation( - const simd8 prev2, const simd8 prev3) { - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). 
All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. - return simd8(is_third_byte | is_fourth_byte) > int8_t(0); -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -/* begin file src/generic/stage1/utf8_lookup4_algorithm.h */ -namespace simdjson { -namespace ppc64 { -namespace { -namespace utf8_validation { - -using namespace simd; - -simdjson_really_inline simd8 check_special_cases( - const simd8 input, const simd8 prev1) { - // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII) - // Bit 1 = Too Long (ASCII followed by continuation) - // Bit 2 = Overlong 3-byte - // Bit 4 = Surrogate - // Bit 5 = Overlong 2-byte - // Bit 7 = Two Continuations - constexpr const uint8_t TOO_SHORT = 1 << 0; // 11______ 0_______ - // 11______ 11______ - constexpr const uint8_t TOO_LONG = 1 << 1; // 0_______ 10______ - constexpr const uint8_t OVERLONG_3 = 1 << 2; // 11100000 100_____ - constexpr const uint8_t SURROGATE = 1 << 4; // 11101101 101_____ - constexpr const uint8_t OVERLONG_2 = 1 << 5; // 1100000_ 10______ - constexpr const uint8_t TWO_CONTS = 1 << 7; // 10______ 10______ - constexpr const uint8_t TOO_LARGE = 1 << 3; // 11110100 1001____ - // 11110100 101_____ - // 11110101 1001____ - // 11110101 101_____ - // 1111011_ 1001____ - // 1111011_ 101_____ - // 11111___ 1001____ - // 11111___ 101_____ - constexpr const uint8_t TOO_LARGE_1000 = 1 << 6; - // 11110101 1000____ - // 1111011_ 1000____ - // 11111___ 1000____ - constexpr const uint8_t OVERLONG_4 = 1 << 6; // 11110000 1000____ - - const simd8 byte_1_high = prev1.shr<4>().lookup_16( - // 0_______ ________ - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - // 10______ ________ - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - // 1100____ ________ - TOO_SHORT | OVERLONG_2, - // 1101____ ________ - TOO_SHORT, - // 1110____ ________ - TOO_SHORT | OVERLONG_3 | SURROGATE, - // 1111____ ________ - 
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4); - constexpr const uint8_t CARRY = - TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 . - const simd8 byte_1_low = - (prev1 & 0x0F) - .lookup_16( - // ____0000 ________ - CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, - // ____0001 ________ - CARRY | OVERLONG_2, - // ____001_ ________ - CARRY, - CARRY, - - // ____0100 ________ - CARRY | TOO_LARGE, - // ____0101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____011_ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - - // ____1___ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____1101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000); - const simd8 byte_2_high = input.shr<4>().lookup_16( - // ________ 0_______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - - // ________ 1000____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | - OVERLONG_4, - // ________ 1001____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, - // ________ 101_____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - - // ________ 11______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT); - return (byte_1_high & byte_1_low & byte_2_high); -} -simdjson_really_inline simd8 check_multibyte_lengths( - const simd8 input, - const simd8 prev_input, - const simd8 sc) { - simd8 prev2 = input.prev<2>(prev_input); - simd8 prev3 = input.prev<3>(prev_input); - simd8 must23 = - simd8(must_be_2_3_continuation(prev2, prev3)); - simd8 must23_80 = must23 & uint8_t(0x80); - return must23_80 ^ sc; -} - -// -// Return nonzero if there are incomplete multibyte 
characters at the end of the -// block: -// e.g. if there is a 4-byte character, but it's 3 bytes from the end. -// -simdjson_really_inline simd8 is_incomplete( - const simd8 input) { - // If the previous input's last 3 bytes match this, they're too short (they - // ended at EOF): - // ... 1111____ 111_____ 11______ - static const uint8_t max_array[32] = {255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 0b11110000u - 1, - 0b11100000u - 1, - 0b11000000u - 1}; - const simd8 max_value( - &max_array[sizeof(max_array) - sizeof(simd8)]); - return input.gt_bits(max_value); -} - -struct utf8_checker { - // If this is nonzero, there has been a UTF-8 error. - simd8 error; - // The last input we received - simd8 prev_input_block; - // Whether the last input we received was incomplete (used for ASCII fast - // path) - simd8 prev_incomplete; - - // - // Check whether the current bytes are valid UTF-8. - // - simdjson_really_inline void check_utf8_bytes( - const simd8 input, const simd8 prev_input) { - // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or - // 4+ lead bytes - // (2, 3, 4-byte leads become large positive numbers instead of small - // negative numbers) - simd8 prev1 = input.prev<1>(prev_input); - simd8 sc = check_special_cases(input, prev1); - this->error |= check_multibyte_lengths(input, prev_input, sc); - } - - // The only problem that can happen at EOF is that a multibyte character is - // too short - // or a byte value too large in the last bytes: check_special_cases only - // checks for bytes - // too large in the first of two bytes. - simdjson_really_inline void check_eof() { - // If the previous block had incomplete UTF-8 characters at the end, an - // ASCII block can't - // possibly finish them. 
- this->error |= this->prev_incomplete; - } - - simdjson_really_inline void check_next_input( - const simd8x64 &input) { - if (simdjson_likely(is_ascii(input))) { - this->error |= this->prev_incomplete; - } else { - // you might think that a for-loop would work, but under Visual - // Studio, it is not good enough. - static_assert( - (simd8x64::NUM_CHUNKS == 2) || - (simd8x64::NUM_CHUNKS == 4), - "We support either two or four chunks per 64-byte block."); - if (simd8x64::NUM_CHUNKS == 2) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - } else if (simd8x64::NUM_CHUNKS == 4) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - this->check_utf8_bytes(input.chunks[2], input.chunks[1]); - this->check_utf8_bytes(input.chunks[3], input.chunks[2]); - } - this->prev_incomplete = - is_incomplete(input.chunks[simd8x64::NUM_CHUNKS - 1]); - this->prev_input_block = - input.chunks[simd8x64::NUM_CHUNKS - 1]; - } - } - // do not forget to call check_eof! - simdjson_really_inline error_code errors() { - return this->error.any_bits_set_anywhere() ? 
error_code::UTF8_ERROR - : error_code::SUCCESS; - } - -}; // struct utf8_checker -} // namespace utf8_validation - -using utf8_validation::utf8_checker; - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/utf8_lookup4_algorithm.h */ -/* begin file src/generic/stage1/json_structural_indexer.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -/* begin file src/generic/stage1/buf_block_reader.h */ -namespace simdjson { -namespace ppc64 { -namespace { - -// Walks through a buffer in block-sized increments, loading the last part with -// spaces -template -struct buf_block_reader { - public: - simdjson_really_inline buf_block_reader(const uint8_t *_buf, size_t _len); - simdjson_really_inline size_t block_index(); - simdjson_really_inline bool has_full_block() const; - simdjson_really_inline const uint8_t *full_block() const; - /** - * Get the last block, padded with spaces. - * - * There will always be a last block, with at least 1 byte, unless len == 0 - * (in which case this - * function fills the buffer with spaces and returns 0. In particular, if - * len == STEP_SIZE there - * will be 0 full_blocks and 1 remainder block with STEP_SIZE bytes and no - * spaces for padding. - * - * @return the number of effective characters in the last block. 
- */ - simdjson_really_inline size_t get_remainder(uint8_t *dst) const; - simdjson_really_inline void advance(); - - private: - const uint8_t *buf; - const size_t len; - const size_t lenminusstep; - size_t idx; -}; - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text_64(const uint8_t *text) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < sizeof(simd8x64); i++) { - buf[i] = int8_t(text[i]) < ' ' ? '_' : int8_t(text[i]); - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text(const simd8x64 &in) { - static char buf[sizeof(simd8x64) + 1]; - in.store(reinterpret_cast(buf)); - for (size_t i = 0; i < sizeof(simd8x64); i++) { - if (buf[i] < ' ') { - buf[i] = '_'; - } - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -simdjson_unused static char *format_mask(uint64_t mask) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < 64; i++) { - buf[i] = (mask & (size_t(1) << i)) ? 'X' : ' '; - } - buf[64] = '\0'; - return buf; -} - -template -simdjson_really_inline buf_block_reader::buf_block_reader( - const uint8_t *_buf, size_t _len) - : buf{_buf}, - len{_len}, - lenminusstep{len < STEP_SIZE ? 
0 : len - STEP_SIZE}, - idx{0} {} - -template -simdjson_really_inline size_t buf_block_reader::block_index() { - return idx; -} - -template -simdjson_really_inline bool buf_block_reader::has_full_block() - const { - return idx < lenminusstep; -} - -template -simdjson_really_inline const uint8_t *buf_block_reader::full_block() - const { - return &buf[idx]; -} - -template -simdjson_really_inline size_t -buf_block_reader::get_remainder(uint8_t *dst) const { - if (len == idx) { - return 0; - } // memcpy(dst, null, 0) will trigger an error with some sanitizers - std::memset(dst, 0x20, STEP_SIZE); // std::memset STEP_SIZE because it's - // more efficient to write out 8 or 16 - // bytes at once. - std::memcpy(dst, buf + idx, len - idx); - return len - idx; -} - -template -simdjson_really_inline void buf_block_reader::advance() { - idx += STEP_SIZE; -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/buf_block_reader.h */ -/* begin file src/generic/stage1/json_string_scanner.h */ -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -struct json_string_block { - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_string_block(uint64_t backslash, - uint64_t escaped, - uint64_t quote, - uint64_t in_string) - : _backslash(backslash), - _escaped(escaped), - _quote(quote), - _in_string(in_string) {} - - // Escaped characters (characters following an escape() character) - simdjson_really_inline uint64_t escaped() const { return _escaped; } - // Escape characters (backslashes that are not escaped--i.e. 
in \\, includes - // only the first \) - simdjson_really_inline uint64_t escape() const { - return _backslash & ~_escaped; - } - // Real (non-backslashed) quotes - simdjson_really_inline uint64_t quote() const { return _quote; } - // Start quotes of strings - simdjson_really_inline uint64_t string_start() const { - return _quote & _in_string; - } - // End quotes of strings - simdjson_really_inline uint64_t string_end() const { - return _quote & ~_in_string; - } - // Only characters inside the string (not including the quotes) - simdjson_really_inline uint64_t string_content() const { - return _in_string & ~_quote; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_inside_string(uint64_t mask) const { - return mask & _in_string; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const { - return mask & ~_in_string; - } - // Tail of string (everything except the start quote) - simdjson_really_inline uint64_t string_tail() const { - return _in_string ^ _quote; - } - - // backslash characters - uint64_t _backslash; - // escaped characters (backslashed--does not include the hex characters - // after \u) - uint64_t _escaped; - // real quotes (non-backslashed ones) - uint64_t _quote; - // string characters (includes start quote but not end quote) - uint64_t _in_string; -}; - -// Scans blocks for string characters, storing the state necessary to do so -class json_string_scanner { - public: - simdjson_really_inline json_string_block - next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Intended to be defined by the implementation - simdjson_really_inline uint64_t find_escaped(uint64_t escape); - simdjson_really_inline uint64_t 
find_escaped_branchless(uint64_t escape); - - // Whether the last iteration was still inside a string (all 1's = true, all - // 0's = false). - uint64_t prev_in_string = 0ULL; - // Whether the first character of the next iteration is escaped. - uint64_t prev_escaped = 0ULL; -}; - -// -// Finds escaped characters (characters following \). -// -// Handles runs of backslashes like \\\" and \\\\" correctly (yielding 0101 and -// 01010, respectively). -// -// Does this by: -// - Shift the escape mask to get potentially escaped characters (characters -// after backslashes). -// - Mask escaped sequences that start on *even* bits with 1010101010 (odd bits -// are escaped, even bits are not) -// - Mask escaped sequences that start on *odd* bits with 0101010101 (even bits -// are escaped, odd bits are not) -// -// To distinguish between escaped sequences starting on even/odd bits, it finds -// the start of all -// escape sequences, filters out the ones that start on even bits, and adds that -// to the mask of -// escape sequences. This causes the addition to clear out the sequences -// starting on odd bits (since -// the start bit causes a carry), and leaves even-bit sequences alone. 
-// -// Example: -// -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// escape | xxx | xx xxx xxx xx xx | Removed overflow backslash; -// will | it into follows_escape -// odd_starts | x | x x x | escape & ~even_bits & -// ~follows_escape -// even_seq | c| cxxx c xx c | c = carry bit -- will be -// masked out later -// invert_mask | | cxxx c xx c| even_seq << 1 -// follows_escape | xx | x xx xxx xxx xx xx | Includes overflow bit -// escaped | x | x x x x x x x x | -// desired | x | x x x x x x x x | -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// -simdjson_really_inline uint64_t -json_string_scanner::find_escaped_branchless(uint64_t backslash) { - // If there was overflow, pretend the first character isn't a backslash - backslash &= ~prev_escaped; - uint64_t follows_escape = backslash << 1 | prev_escaped; - - // Get sequences starting on even bits by clearing out the odd series using - // + - const uint64_t even_bits = 0x5555555555555555ULL; - uint64_t odd_sequence_starts = backslash & ~even_bits & ~follows_escape; - uint64_t sequences_starting_on_even_bits; - prev_escaped = add_overflow( - odd_sequence_starts, backslash, &sequences_starting_on_even_bits); - uint64_t invert_mask = - sequences_starting_on_even_bits - << 1; // The mask we want to return is the *escaped* bits, not escapes. - - // Mask every other backslashed character as an escaped character - // Flip the mask for sequences that start on even bits, to correct them - return (even_bits ^ invert_mask) & follows_escape; -} - -// -// Return a mask of all string characters plus end quotes. -// -// prev_escaped is overflow saying whether the next character is escaped. -// prev_in_string is overflow saying whether we're still in a string. -// -// Backslash sequences outside of quotes will be detected in stage 2. 
-// -simdjson_really_inline json_string_block -json_string_scanner::next(const simd::simd8x64 &in) { - const uint64_t backslash = in.eq('\\'); - const uint64_t escaped = find_escaped(backslash); - const uint64_t quote = in.eq('"') & ~escaped; - - // - // prefix_xor flips on bits inside the string (and flips off the end quote). - // - // Then we xor with prev_in_string: if we were in a string already, its - // effect is flipped - // (characters inside strings are outside, and characters outside strings - // are inside). - // - const uint64_t in_string = prefix_xor(quote) ^ prev_in_string; - - // - // Check if we're still in a string at the end of the box so the next block - // will know - // - // right shift of a signed value expected to be well-defined and standard - // compliant as of C++20, John Regher from Utah U. says this is fine code - // - prev_in_string = uint64_t(static_cast(in_string) >> 63); - - // Use ^ to turn the beginning quote off, and the end quote on. - - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_string_block(backslash, escaped, quote, in_string); -} - -simdjson_really_inline error_code json_string_scanner::finish() { - if (prev_in_string) { - return UNCLOSED_STRING; - } - return SUCCESS; -} - -} // namespace stage1 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/json_string_scanner.h */ -/* begin file src/generic/stage1/json_scanner.h */ -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -/** - * A block of scanned json, with information on operators and scalars. - * - * We seek to identify pseudo-structural characters. Anything that is inside - * a string must be omitted (hence & ~_string.string_tail()). - * Otherwise, pseudo-structural characters come in two forms. - * 1. We have the structural characters ([,],{,},:, comma). The - * term 'structural character' is from the JSON RFC. 
- * 2. We have the 'scalar pseudo-structural characters'. - * Scalars are quotes, and any character except structural characters and - * white space. - * - * To identify the scalar pseudo-structural characters, we must look at what - * comes - * before them: it must be a space, a quote or a structural characters. - * Starting with simdjson v0.3, we identify them by - * negation: we identify everything that is followed by a non-quote scalar, - * and we negate that. Whatever remains must be a 'scalar pseudo-structural - * character'. - */ -struct json_block { - public: - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_block( - json_string_block &&string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(std::move(string)), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - simdjson_really_inline json_block( - json_string_block string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(string), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - - /** - * The start of structurals. - * In simdjson prior to v0.3, these were called the pseudo-structural - *characters. - **/ - simdjson_really_inline uint64_t structural_start() const noexcept { - return potential_structural_start() & ~_string.string_tail(); - } - /** All JSON whitespace (i.e. 
not in a string) */ - simdjson_really_inline uint64_t whitespace() const noexcept { - return non_quote_outside_string(_characters.whitespace()); - } - - // Helpers - - /** Whether the given characters are inside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t non_quote_inside_string(uint64_t mask) const - noexcept { - return _string.non_quote_inside_string(mask); - } - /** Whether the given characters are outside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const noexcept { - return _string.non_quote_outside_string(mask); - } - - // string and escape characters - json_string_block _string; - // whitespace, structural characters ('operators'), scalars - json_character_block _characters; - // whether the previous character was a scalar - uint64_t _follows_potential_nonquote_scalar; - - private: - // Potential structurals (i.e. disregarding strings) - - /** - * structural elements ([,],{,},:, comma) plus scalar starts like 123, true - *and "abc". - * They may reside inside a string. - **/ - simdjson_really_inline uint64_t potential_structural_start() const - noexcept { - return _characters.op() | potential_scalar_start(); - } - /** - * The start of non-operator runs, like 123, true and "abc". - * It main reside inside a string. - **/ - simdjson_really_inline uint64_t potential_scalar_start() const noexcept { - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // Whenever it is preceded by something that is not a structural element - // ({,},[,],:, ") nor a white-space - // then we know that it is irrelevant structurally. - return _characters.scalar() & ~follows_potential_scalar(); - } - /** - * Whether the given character is immediately after a non-operator like 123, - * true. - * The characters following a quote are not included. 
- */ - simdjson_really_inline uint64_t follows_potential_scalar() const noexcept { - // _follows_potential_nonquote_scalar: is defined as marking any - // character that follows a character - // that is not a structural element ({,},[,],:, comma) nor a quote (") - // and that is not a - // white space. - // It is understood that within quoted region, anything at all could be - // marked (irrelevant). - return _follows_potential_nonquote_scalar; - } -}; - -/** - * Scans JSON for important bits: structural characters or 'operators', strings, - * and scalars. - * - * The scanner starts by calculating two distinct things: - * - string characters (taking \" into account) - * - structural characters or 'operators' ([]{},:, comma) - * and scalars (runs of non-operators like 123, true and "abc") - * - * To minimize data dependency (a key component of the scanner's speed), it - * finds these in parallel: - * in particular, the operator/scalar bit will find plenty of things that are - * actually part of - * strings. When we're done, json_block will fuse the two together by masking - * out tokens that are - * part of a string. - */ -class json_scanner { - public: - json_scanner() {} - simdjson_really_inline json_block next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Whether the last character of the previous iteration is part of a scalar - // token - // (anything except whitespace or a structural character/'operator'). - uint64_t prev_scalar = 0ULL; - json_string_scanner string_scanner{}; -}; - - -// -// Check if the current character immediately follows a matching character. 
-// -// For example, this checks for quotes with backslashes in front of them: -// -// const uint64_t backslashed_quote = in.eq('"') & -// immediately_follows(in.eq('\'), prev_backslash); -// -simdjson_really_inline uint64_t follows(const uint64_t match, - uint64_t &overflow) { - const uint64_t result = match << 1 | overflow; - overflow = match >> 63; - return result; -} - -simdjson_really_inline json_block -json_scanner::next(const simd::simd8x64 &in) { - json_string_block strings = string_scanner.next(in); - // identifies the white-space and the structural characters - json_character_block characters = json_character_block::classify(in); - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // We want follows_scalar to mark anything that follows a non-quote scalar - // (so letters and numbers). - // - // A terminal quote should either be followed by a structural character - // (comma, brace, bracket, colon) - // or nothing. However, we still want ' "a string"true ' to mark the 't' of - // 'true' as a potential - // pseudo-structural character just like we would if we had ' "a string" - // true '; otherwise we - // may need to add an extra check when parsing strings. - // - // Performance: there are many ways to skin this cat. - const uint64_t nonquote_scalar = characters.scalar() & ~strings.quote(); - uint64_t follows_nonquote_scalar = follows(nonquote_scalar, prev_scalar); - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_block(strings, // strings is a function-local object so either - // it moves or the copy is elided. 
- characters, - follows_nonquote_scalar); -} - -simdjson_really_inline error_code json_scanner::finish() { - return string_scanner.finish(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/json_scanner.h */ -/* begin file src/generic/stage1/json_minifier.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -class json_minifier { - public: - template - static error_code minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept; - - private: - simdjson_really_inline json_minifier(uint8_t *_dst) : dst{_dst} {} - template - simdjson_really_inline void step( - const uint8_t *block_buf, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block); - simdjson_really_inline error_code finish(uint8_t *dst_start, - size_t &dst_len); - json_scanner scanner{}; - uint8_t *dst; -}; - -simdjson_really_inline void json_minifier::next( - const simd::simd8x64 &in, const json_block &block) { - uint64_t mask = block.whitespace(); - dst += in.compress(mask, dst); -} - -simdjson_really_inline error_code json_minifier::finish(uint8_t *dst_start, - size_t &dst_len) { - error_code error = scanner.finish(); - if (error) { - dst_len = 0; - return error; - } - dst_len = dst - dst_start; - return SUCCESS; -} - -template <> -simdjson_really_inline void json_minifier::step<128>( - const uint8_t *block_buf, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - simd::simd8x64 in_2(block_buf + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1); 
- this->next(in_2, block_2); - reader.advance(); -} - -template <> -simdjson_really_inline void json_minifier::step<64>( - const uint8_t *block_buf, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - json_block block_1 = scanner.next(in_1); - this->next(block_buf, block_1); - reader.advance(); -} - -template -error_code json_minifier::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept { - buf_block_reader reader(buf, len); - json_minifier minifier(dst); - - // Index the first n-1 blocks - while (reader.has_full_block()) { - minifier.step(reader.full_block(), reader); - } - - // Index the last (remainder) block, padded with spaces - uint8_t block[STEP_SIZE]; - size_t remaining_bytes = reader.get_remainder(block); - if (remaining_bytes > 0) { - // We do not want to write directly to the output stream. Rather, we - // write - // to a local buffer (for safety). - uint8_t out_block[STEP_SIZE]; - uint8_t *const guarded_dst{minifier.dst}; - minifier.dst = out_block; - minifier.step(block, reader); - size_t to_write = minifier.dst - out_block; - // In some cases, we could be enticed to consider the padded spaces - // as part of the string. This is fine as long as we do not write more - // than we consumed. - if (to_write > remaining_bytes) { - to_write = remaining_bytes; - } - memcpy(guarded_dst, out_block, to_write); - minifier.dst = guarded_dst + to_write; - } - return minifier.finish(dst, dst_len); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/json_minifier.h */ -/* begin file src/generic/stage1/find_next_document_index.h */ -namespace simdjson { -namespace ppc64 { -namespace { - -/** - * This algorithm is used to quickly identify the last structural position that - * makes up a complete document. 
- * - * It does this by going backwards and finding the last *document boundary* (a - * place where one value follows another without a comma between them). If the - * last document (the characters after the boundary) has an equal number of - * start and end brackets, it is considered complete. - * - * Simply put, we iterate over the structural characters, starting from - * the end. We consider that we found the end of a JSON document when the - * first element of the pair is NOT one of these characters: '{' '[' ':' ',' - * and when the second element is NOT one of these characters: '}' ']' ':' ','. - * - * This simple comparison works most of the time, but it does not cover cases - * where the batch's structural indexes contain a perfect amount of documents. - * In such a case, we do not have access to the structural index which follows - * the last document, therefore, we do not have access to the second element in - * the pair, and that means we cannot identify the last document. To fix this - * issue, we keep a count of the open and closed curly/square braces we found - * while searching for the pair. When we find a pair AND the count of open and - * closed curly/square braces is the same, we know that we just passed a - * complete document, therefore the last json buffer location is the end of the - * batch. 
- */ -simdjson_really_inline uint32_t -find_next_document_index(dom_parser_implementation &parser) { - // Variant: do not count separately, just figure out depth - if (parser.n_structural_indexes == 0) { - return 0; - } - auto arr_cnt = 0; - auto obj_cnt = 0; - for (auto i = parser.n_structural_indexes - 1; i > 0; i--) { - auto idxb = parser.structural_indexes[i]; - switch (parser.buf[idxb]) { - case ':': - case ',': - continue; - case '}': - obj_cnt--; - continue; - case ']': - arr_cnt--; - continue; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - auto idxa = parser.structural_indexes[i - 1]; - switch (parser.buf[idxa]) { - case '{': - case '[': - case ':': - case ',': - continue; - } - // Last document is complete, so the next document will appear after! - if (!arr_cnt && !obj_cnt) { - return parser.n_structural_indexes; - } - // Last document is incomplete; mark the document at i + 1 as the next - // one - return i; - } - // If we made it to the end, we want to finish counting to see if we have a - // full document. - switch (parser.buf[parser.structural_indexes[0]]) { - case '}': - obj_cnt--; - break; - case ']': - arr_cnt--; - break; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - if (!arr_cnt && !obj_cnt) { - // We have a complete document. 
- return parser.n_structural_indexes; - } - return 0; -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/find_next_document_index.h */ - -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -class bit_indexer { - public: - uint32_t *tail; - - simdjson_really_inline bit_indexer(uint32_t *index_buf) : tail(index_buf) {} - - // flatten out values in 'bits' assuming that they are are to have values of - // idx - // plus their position in the bitvector, and store these indexes at - // base_ptr[base] incrementing base as we go - // will potentially store extra values beyond end of valid bits, so base_ptr - // needs to be large enough to handle this - simdjson_really_inline void write(uint32_t idx, uint64_t bits) { - // In some instances, the next branch is expensive because it is - // mispredicted. - // Unfortunately, in other cases, - // it helps tremendously. - if (bits == 0) return; -#if defined(SIMDJSON_PREFER_REVERSE_BITS) - /** - * ARM lacks a fast trailing zero instruction, but it has a fast - * bit reversal instruction and a fast leading zero instruction. - * Thus it may be profitable to reverse the bits (once) and then - * to rely on a sequence of instructions that call the leading - * zero instruction. - * - * Performance notes: - * The chosen routine is not optimal in terms of data dependency - * since zero_leading_bit might require two instructions. However, - * it tends to minimize the total number of instructions which is - * beneficial. - */ - - uint64_t rev_bits = reverse_bits(bits); - int cnt = static_cast(count_ones(bits)); - int i = 0; - // Do the first 8 all together - for (; i < 8; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). 
- if (simdjson_unlikely(cnt > 8)) { - i = 8; - for (; i < 16; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. - if (simdjson_unlikely(cnt > 16)) { - i = 16; - while (rev_bits != 0) { - int lz = leading_zeroes(rev_bits); - this->tail[i++] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - } - } - this->tail += cnt; -#else // SIMDJSON_PREFER_REVERSE_BITS - /** - * Under recent x64 systems, we often have both a fast trailing zero - * instruction and a fast 'clear-lower-bit' instruction so the following - * algorithm can be competitive. - */ - - int cnt = static_cast(count_ones(bits)); - // Do the first 8 all together - for (int i = 0; i < 8; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). - if (simdjson_unlikely(cnt > 8)) { - for (int i = 8; i < 16; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. 
- if (simdjson_unlikely(cnt > 16)) { - int i = 16; - do { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - i++; - } while (i < cnt); - } - } - - this->tail += cnt; -#endif - } -}; - -class json_structural_indexer { - public: - /** - * Find the important bits of JSON in a 128-byte chunk, and add them to - * structural_indexes. - * - * @param partial Setting the partial parameter to true allows the - * find_structural_bits to - * tolerate unclosed strings. The caller should still ensure that the - * input is valid UTF-8. If - * you are processing substrings, you may want to call on a function like - * trimmed_length_safe_utf8. - */ - template - static error_code index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept; - - private: - simdjson_really_inline json_structural_indexer( - uint32_t *structural_indexes); - template - simdjson_really_inline void step( - const uint8_t *block, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block, - size_t idx); - simdjson_really_inline error_code finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial); - - json_scanner scanner{}; - utf8_checker checker{}; - bit_indexer indexer; - uint64_t prev_structurals = 0; - uint64_t unescaped_chars_error = 0; -}; - -simdjson_really_inline json_structural_indexer::json_structural_indexer( - uint32_t *structural_indexes) - : indexer{structural_indexes} {} - -// Skip the last character if it is partial -simdjson_really_inline size_t trim_partial_utf8(const uint8_t *buf, - size_t len) { - if (simdjson_unlikely(len < 3)) { - switch (len) { - case 2: - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 2 bytes left - return len; - case 1: - if (buf[len - 
1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - return len; - case 0: - return len; - } - } - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 1 byte left - if (buf[len - 3] >= 0b11110000) { - return len - 3; - } // 4-byte characters with only 3 bytes left - return len; -} - -// -// PERF NOTES: -// We pipe 2 inputs through these stages: -// 1. Load JSON into registers. This takes a long time and is highly -// parallelizable, so we load -// 2 inputs' worth at once so that by the time step 2 is looking for them -// input, it's available. -// 2. Scan the JSON for critical data: strings, scalars and operators. This is -// the critical path. -// The output of step 1 depends entirely on this information. These functions -// don't quite use -// up enough CPU: the second half of the functions is highly serial, only -// using 1 execution core -// at a time. The second input's scans has some dependency on the first ones -// finishing it, but -// they can make a lot of progress before they need that information. -// 3. Step 1 doesn't use enough capacity, so we run some extra stuff while we're -// waiting for that -// to finish: utf-8 checks and generating the output from the last iteration. -// -// The reason we run 2 inputs at a time, is steps 2 and 3 are *still* not enough -// to soak up all -// available capacity with just one input. Running 2 at a time seems to give the -// CPU a good enough -// workout. -// -template -error_code json_structural_indexer::index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept { - if (simdjson_unlikely(len > parser.capacity())) { - return CAPACITY; - } - // We guard the rest of the code so that we can assume that len > 0 - // throughout. 
- if (len == 0) { - return EMPTY; - } - if (is_streaming(partial)) { - len = trim_partial_utf8(buf, len); - // If you end up with an empty window after trimming - // the partial UTF-8 bytes, then chances are good that you - // have an UTF-8 formatting error. - if (len == 0) { - return UTF8_ERROR; - } - } - buf_block_reader reader(buf, len); - json_structural_indexer indexer(parser.structural_indexes.get()); - - // Read all but the last block - while (reader.has_full_block()) { - indexer.step(reader.full_block(), reader); - } - // Take care of the last block (will always be there unless file is empty - // which is - // not supposed to happen.) - uint8_t block[STEP_SIZE]; - if (simdjson_unlikely(reader.get_remainder(block) == 0)) { - return UNEXPECTED_ERROR; - } - indexer.step(block, reader); - return indexer.finish(parser, reader.block_index(), len, partial); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<128>( - const uint8_t *block, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block); - simd::simd8x64 in_2(block + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1, reader.block_index()); - this->next(in_2, block_2, reader.block_index() + 64); - reader.advance(); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<64>( - const uint8_t *block, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block); - json_block block_1 = scanner.next(in_1); - this->next(in_1, block_1, reader.block_index()); - reader.advance(); -} - -simdjson_really_inline void json_structural_indexer::next( - const simd::simd8x64 &in, const json_block &block, size_t idx) { - uint64_t unescaped = in.lteq(0x1F); - checker.check_next_input(in); - indexer.write(uint32_t(idx - 64), prev_structurals); // Output *last* - // iteration's - // structurals to the - // parser - prev_structurals = block.structural_start(); - unescaped_chars_error |= 
block.non_quote_inside_string(unescaped); -} - -simdjson_really_inline error_code -json_structural_indexer::finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial) { - // Write out the final iteration's structurals - indexer.write(uint32_t(idx - 64), prev_structurals); - error_code error = scanner.finish(); - // We deliberately break down the next expression so that it is - // human readable. - const bool should_we_exit = - is_streaming(partial) - ? ((error != SUCCESS) && - (error != - UNCLOSED_STRING)) // when partial we tolerate UNCLOSED_STRING - : (error != SUCCESS); // if partial is false, we must have SUCCESS - const bool have_unclosed_string = (error == UNCLOSED_STRING); - if (simdjson_unlikely(should_we_exit)) { - return error; - } - - if (unescaped_chars_error) { - return UNESCAPED_CHARS; - } - parser.n_structural_indexes = - uint32_t(indexer.tail - parser.structural_indexes.get()); - /*** - * The On Demand API requires special padding. - * - * This is related to https://github.com/simdjson/simdjson/issues/906 - * Basically, we want to make sure that if the parsing continues beyond the - *last (valid) - * structural character, it quickly stops. - * Only three structural characters can be repeated without triggering an - *error in JSON: [,] and }. - * We repeat the padding character (at 'len'). We don't know what it is, but - *if the parsing - * continues, then it must be [,] or }. - * Suppose it is ] or }. We backtrack to the first character, what could it - *be that would - * not trigger an error? It could be ] or } but no, because you can't start - *a document that way. - * It can't be a comma, a colon or any simple value. So the only way we - *could continue is - * if the repeated character is [. But if so, the document must start with - *[. But if the document - * starts with [, it should end with ]. If we enforce that rule, then we - *would get - * ][[ which is invalid. 
- * - * This is illustrated with the test array_iterate_unclosed_error() on the - *following input: - * R"({ "a": [,,)" - **/ - parser.structural_indexes[parser.n_structural_indexes] = - uint32_t(len); // used later in partial == stage1_mode::streaming_final - parser.structural_indexes[parser.n_structural_indexes + 1] = uint32_t(len); - parser.structural_indexes[parser.n_structural_indexes + 2] = 0; - parser.next_structural_index = 0; - // a valid JSON file cannot have zero structural indexes - we should have - // found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return EMPTY; - } - if (simdjson_unlikely( - parser.structural_indexes[parser.n_structural_indexes - 1] > len)) { - return UNEXPECTED_ERROR; - } - if (partial == stage1_mode::streaming_partial) { - // If we have an unclosed string, then the last structural - // will be the quote and we want to make sure to omit it. - if (have_unclosed_string) { - parser.n_structural_indexes--; - // a valid JSON file cannot have zero structural indexes - we should - // have found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return CAPACITY; - } - } - // We truncate the input to the end of the last complete document (or - // zero). - auto new_structural_indexes = find_next_document_index(parser); - if (new_structural_indexes == 0 && parser.n_structural_indexes > 0) { - if (parser.structural_indexes[0] == 0) { - // If the buffer is partial and we started at index 0 but the - // document is - // incomplete, it's too big to parse. - return CAPACITY; - } else { - // It is possible that the document could be parsed, we just had - // a lot - // of white space. 
- parser.n_structural_indexes = 0; - return EMPTY; - } - } - - parser.n_structural_indexes = new_structural_indexes; - } else if (partial == stage1_mode::streaming_final) { - if (have_unclosed_string) { - parser.n_structural_indexes--; - } - // We truncate the input to the end of the last complete document (or - // zero). - // Because partial == stage1_mode::streaming_final, it means that we may - // silently ignore trailing garbage. Though it sounds bad, we do it - // deliberately because many people who have streams of JSON documents - // will truncate them for processing. E.g., imagine that you are - // uncompressing - // the data from a size file or receiving it in chunks from the network. - // You - // may not know where exactly the last document will be. Meanwhile the - // document_stream instances allow people to know the JSON documents - // they are - // parsing (see the iterator.source() method). - parser.n_structural_indexes = find_next_document_index(parser); - // We store the initial n_structural_indexes so that the client can see - // whether we used truncation. If initial_n_structural_indexes == - // parser.n_structural_indexes, - // then this will query - // parser.structural_indexes[parser.n_structural_indexes] which is len, - // otherwise, it will copy some prior index. - parser.structural_indexes[parser.n_structural_indexes + 1] = - parser.structural_indexes[parser.n_structural_indexes]; - // This next line is critical, do not change it unless you understand - // what you are - // doing. - parser.structural_indexes[parser.n_structural_indexes] = uint32_t(len); - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - // We tolerate an unclosed string at the very end of the stream. - // Indeed, users - // often load their data in bulk without being careful and they want - // us to ignore - // the trailing garbage. 
- return EMPTY; - } - } - checker.check_eof(); - return checker.errors(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/json_structural_indexer.h */ -/* begin file src/generic/stage1/utf8_validator.h */ -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -/** - * Validates that the string is actual UTF-8. - */ -template -bool generic_validate_utf8(const uint8_t *input, size_t length) { - checker c{}; - buf_block_reader<64> reader(input, length); - while (reader.has_full_block()) { - simd::simd8x64 in(reader.full_block()); - c.check_next_input(in); - reader.advance(); - } - uint8_t block[64]{}; - reader.get_remainder(block); - simd::simd8x64 in(block); - c.check_next_input(in); - reader.advance(); - c.check_eof(); - return c.errors() == error_code::SUCCESS; -} - -bool generic_validate_utf8(const char *input, size_t length) { - return generic_validate_utf8( - reinterpret_cast(input), length); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage1/utf8_validator.h */ - -// -// Stage 2 -// - -/* begin file src/generic/stage2/tape_builder.h */ -/* begin file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/logger.h */ -// This is for an internal-only stage 2 specific logger. -// Set LOG_ENABLED = true to log what stage 2 is doing! 
-namespace simdjson { -namespace ppc64 { -namespace { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static constexpr const int LOG_INDEX_LEN = 5; - -static int log_depth; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static simdjson_really_inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -// Print the header and set up log_start -static simdjson_really_inline void log_start() { - if (LOG_ENABLED) { - log_depth = 0; - printf("\n"); - printf("| %-*s | %-*s | %-*s | %-*s | Detail |\n", - LOG_EVENT_LEN, - "Event", - LOG_BUFFER_LEN, - "Buffer", - LOG_SMALL_BUFFER_LEN, - "Next", - 5, - "Next#"); - printf("|%.*s|%.*s|%.*s|%.*s|--------|\n", - LOG_EVENT_LEN + 2, - DASHES, - LOG_BUFFER_LEN + 2, - DASHES, - LOG_SMALL_BUFFER_LEN + 2, - DASHES, - 5 + 2, - DASHES); - } -} - -simdjson_unused static simdjson_really_inline void log_string( - const char *message) { - if (LOG_ENABLED) { - printf("%s\n", message); - } -} - -// Logs a single line from the stage 2 DOM parser -template -static simdjson_really_inline void log_line(S &structurals, - const char *title_prefix, - const char *title, - const char *detail) { - if (LOG_ENABLED) { - printf("| %*s%s%-*s ", - log_depth * 2, - "", - title_prefix, - LOG_EVENT_LEN - log_depth * 2 - int(strlen(title_prefix)), - title); - auto current_index = 
structurals.at_beginning() - ? nullptr - : structurals.next_structural - 1; - auto next_index = structurals.next_structural; - auto current = current_index ? &structurals.buf[*current_index] - : reinterpret_cast( - " " - " "); - auto next = &structurals.buf[*next_index]; - { - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current[i])); - } - printf(" "); - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next[i])); - } - printf(" "); - } - if (current_index) { - printf("| %*u ", LOG_INDEX_LEN, *current_index); - } else { - printf("| %-*s ", LOG_INDEX_LEN, ""); - } - // printf("| %*u ", LOG_INDEX_LEN, structurals.next_tape_index()); - printf("| %-s ", detail); - printf("|\n"); - } -} - -} // namespace logger -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage2/logger.h */ - -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage2 { - -class json_iterator { - public: - const uint8_t *const buf; - uint32_t *next_structural; - dom_parser_implementation &dom_parser; - uint32_t depth{0}; - - /** - * Walk the JSON document. - * - * The visitor receives callbacks when values are encountered. All callbacks - * pass the iterator as - * the first parameter; some callbacks have other parameters as well: - * - * - visit_document_start() - at the beginning. - * - visit_document_end() - at the end (if things were successful). - * - * - visit_array_start() - at the start `[` of a non-empty array. - * - visit_array_end() - at the end `]` of a non-empty array. 
- * - visit_empty_array() - when an empty array is encountered. - * - * - visit_object_end() - at the start `]` of a non-empty object. - * - visit_object_start() - at the end `]` of a non-empty object. - * - visit_empty_object() - when an empty object is encountered. - * - visit_key(const uint8_t *key) - when a key in an object field is - * encountered. key is - * guaranteed to point at the first quote - * of the string (`"key"`). - * - visit_primitive(const uint8_t *value) - when a value is a string, - * number, boolean or null. - * - visit_root_primitive(iter, uint8_t *value) - when the top-level value - * is a string, number, boolean or null. - * - * - increment_count(iter) - each time a value is found in an array or - * object. - */ - template - simdjson_warn_unused simdjson_really_inline error_code - walk_document(V &visitor) noexcept; - - /** - * Create an iterator capable of walking a JSON document. - * - * The document must have already passed through stage 1. - */ - simdjson_really_inline json_iterator(dom_parser_implementation &_dom_parser, - size_t start_structural_index); - - /** - * Look at the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *peek() const noexcept; - /** - * Advance to the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *advance() noexcept; - /** - * Get the remaining length of the document, from the start of the current - * token. - */ - simdjson_really_inline size_t remaining_len() const noexcept; - /** - * Check if we are at the end of the document. - * - * If this is true, there are no more tokens. 
- */ - simdjson_really_inline bool at_eof() const noexcept; - /** - * Check if we are at the beginning of the document. - */ - simdjson_really_inline bool at_beginning() const noexcept; - simdjson_really_inline uint8_t last_structural() const noexcept; - - /** - * Log that a value has been found. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_value(const char *type) const noexcept; - /** - * Log the start of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_start_value(const char *type) const - noexcept; - /** - * Log the end of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_end_value(const char *type) const noexcept; - /** - * Log an error. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_error(const char *error) const noexcept; - - template - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(V &visitor, const uint8_t *value) noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(V &visitor, const uint8_t *value) noexcept; -}; - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::walk_document(V &visitor) noexcept { - logger::log_start(); - - // - // Start the document - // - if (at_eof()) { - return EMPTY; - } - log_start_value("document"); - SIMDJSON_TRY(visitor.visit_document_start(*this)); - - // - // Read first value - // - { - auto value = advance(); - - // Make sure the outer object or array is closed before continuing; - // otherwise, there are ways we - // could get into memory corruption. 
See - // https://github.com/simdjson/simdjson/issues/906 - if (!STREAMING) { - switch (*value) { - case '{': - if (last_structural() != '}') { - log_value("starting brace unmatched"); - return TAPE_ERROR; - }; - break; - case '[': - if (last_structural() != ']') { - log_value("starting bracket unmatched"); - return TAPE_ERROR; - }; - break; - } - } - - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_root_primitive(*this, value)); - break; - } - } - goto document_end; - -// -// Object parser states -// -object_begin: - log_start_value("object"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = false; - SIMDJSON_TRY(visitor.visit_object_start(*this)); - - { - auto key = advance(); - if (*key != '"') { - log_error("Object does not start with a key"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.increment_count(*this)); - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - -object_field: - if (simdjson_unlikely(*advance() != ':')) { - log_error("Missing colon after key in object"); - return TAPE_ERROR; - } - { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } - } - -object_continue: - switch (*advance()) { - case ',': - 
SIMDJSON_TRY(visitor.increment_count(*this)); - { - auto key = advance(); - if (simdjson_unlikely(*key != '"')) { - log_error( - "Key string missing at beginning of field in object"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - goto object_field; - case '}': - log_end_value("object"); - SIMDJSON_TRY(visitor.visit_object_end(*this)); - goto scope_end; - default: - log_error("No comma between object fields"); - return TAPE_ERROR; - } - -scope_end: - depth--; - if (depth == 0) { - goto document_end; - } - if (dom_parser.is_array[depth]) { - goto array_continue; - } - goto object_continue; - -// -// Array parser states -// -array_begin: - log_start_value("array"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = true; - SIMDJSON_TRY(visitor.visit_array_start(*this)); - SIMDJSON_TRY(visitor.increment_count(*this)); - -array_value : { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } -} - -array_continue: - switch (*advance()) { - case ',': - SIMDJSON_TRY(visitor.increment_count(*this)); - goto array_value; - case ']': - log_end_value("array"); - SIMDJSON_TRY(visitor.visit_array_end(*this)); - goto scope_end; - default: - log_error("Missing comma between array values"); - return TAPE_ERROR; - } - -document_end: - log_end_value("document"); - SIMDJSON_TRY(visitor.visit_document_end(*this)); - - dom_parser.next_structural_index = - uint32_t(next_structural - &dom_parser.structural_indexes[0]); - - // If we didn't make it to the end, it's an error - if 
(!STREAMING && - dom_parser.next_structural_index != dom_parser.n_structural_indexes) { - log_error( - "More than one JSON value at the root of the document, or extra " - "characters at the end of the JSON!"); - return TAPE_ERROR; - } - - return SUCCESS; - -} // walk_document() - -simdjson_really_inline json_iterator::json_iterator( - dom_parser_implementation &_dom_parser, size_t start_structural_index) - : buf{_dom_parser.buf}, - next_structural{&_dom_parser.structural_indexes[start_structural_index]}, - dom_parser{_dom_parser} {} - -simdjson_really_inline const uint8_t *json_iterator::peek() const noexcept { - return &buf[*(next_structural)]; -} -simdjson_really_inline const uint8_t *json_iterator::advance() noexcept { - return &buf[*(next_structural++)]; -} -simdjson_really_inline size_t json_iterator::remaining_len() const noexcept { - return dom_parser.len - *(next_structural - 1); -} - -simdjson_really_inline bool json_iterator::at_eof() const noexcept { - return next_structural == - &dom_parser.structural_indexes[dom_parser.n_structural_indexes]; -} -simdjson_really_inline bool json_iterator::at_beginning() const noexcept { - return next_structural == dom_parser.structural_indexes.get(); -} -simdjson_really_inline uint8_t json_iterator::last_structural() const noexcept { - return buf[dom_parser - .structural_indexes[dom_parser.n_structural_indexes - 1]]; -} - -simdjson_really_inline void json_iterator::log_value(const char *type) const - noexcept { - logger::log_line(*this, "", type, ""); -} - -simdjson_really_inline void json_iterator::log_start_value( - const char *type) const noexcept { - logger::log_line(*this, "+", type, ""); - if (logger::LOG_ENABLED) { - logger::log_depth++; - } -} - -simdjson_really_inline void json_iterator::log_end_value(const char *type) const - noexcept { - if (logger::LOG_ENABLED) { - logger::log_depth--; - } - logger::log_line(*this, "-", type, ""); -} - -simdjson_really_inline void json_iterator::log_error(const char *error) 
const - noexcept { - logger::log_line(*this, "", "ERROR", error); -} - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_root_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_root_string(*this, value); - case 't': - return visitor.visit_root_true_atom(*this, value); - case 'f': - return visitor.visit_root_false_atom(*this, value); - case 'n': - return visitor.visit_root_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_root_number(*this, value); - default: - log_error("Document starts with a non-value character"); - return TAPE_ERROR; - } -} -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_string(*this, value); - case 't': - return visitor.visit_true_atom(*this, value); - case 'f': - return visitor.visit_false_atom(*this, value); - case 'n': - return visitor.visit_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_number(*this, value); - default: - log_error("Non-value found when value was expected!"); - return TAPE_ERROR; - } -} - -} // namespace stage2 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/tape_writer.h */ -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage2 { - -struct tape_writer { - /** The next place to write to tape */ - uint64_t *next_tape_loc; - - /** Write a signed 64-bit value to tape. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - - /** Write an unsigned 64-bit value to tape. 
*/ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - - /** Write a double value to tape. */ - simdjson_really_inline void append_double(double value) noexcept; - - /** - * Append a tape entry (an 8-bit type,and 56 bits worth of value). - */ - simdjson_really_inline void append(uint64_t val, - internal::tape_type t) noexcept; - - /** - * Skip the current tape entry without writing. - * - * Used to skip the start of the container, since we'll come back later to - * fill it in when the - * container ends. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Skip the number of tape entries necessary to write a large u64 or i64. - */ - simdjson_really_inline void skip_large_integer() noexcept; - - /** - * Skip the number of tape entries necessary to write a double. - */ - simdjson_really_inline void skip_double() noexcept; - - /** - * Write a value to a known location on tape. - * - * Used to go back and write out the start of a container after the - * container ends. - */ - simdjson_really_inline static void write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept; - - private: - /** - * Append both the tape entry, and a supplementary value following it. Used - * for types that need - * all 64 bits, such as double and uint64_t. - */ - template - simdjson_really_inline void append2(uint64_t val, - T val2, - internal::tape_type t) noexcept; -}; // struct number_writer - -simdjson_really_inline void tape_writer::append_s64(int64_t value) noexcept { - append2(0, value, internal::tape_type::INT64); -} - -simdjson_really_inline void tape_writer::append_u64(uint64_t value) noexcept { - append(0, internal::tape_type::UINT64); - *next_tape_loc = value; - next_tape_loc++; -} - -/** Write a double value to tape. 
*/ -simdjson_really_inline void tape_writer::append_double(double value) noexcept { - append2(0, value, internal::tape_type::DOUBLE); -} - -simdjson_really_inline void tape_writer::skip() noexcept { next_tape_loc++; } - -simdjson_really_inline void tape_writer::skip_large_integer() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::skip_double() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::append( - uint64_t val, internal::tape_type t) noexcept { - *next_tape_loc = val | ((uint64_t(char(t))) << 56); - next_tape_loc++; -} - -template -simdjson_really_inline void tape_writer::append2( - uint64_t val, T val2, internal::tape_type t) noexcept { - append(val, t); - static_assert(sizeof(val2) == sizeof(*next_tape_loc), - "Type is not 64 bits!"); - memcpy(next_tape_loc, &val2, sizeof(val2)); - next_tape_loc++; -} - -simdjson_really_inline void tape_writer::write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept { - tape_loc = val | ((uint64_t(char(t))) << 56); -} - -} // namespace stage2 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage2/tape_writer.h */ - -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage2 { - -struct tape_builder { - template - simdjson_warn_unused static simdjson_really_inline error_code - parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept; - - /** Called when a non-empty document starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_start(json_iterator &iter) noexcept; - /** Called when a non-empty document ends without error. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_end(json_iterator &iter) noexcept; - - /** Called when a non-empty array starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_start(json_iterator &iter) noexcept; - /** Called when a non-empty array ends. 
*/ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_end(json_iterator &iter) noexcept; - /** Called when an empty array is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_array(json_iterator &iter) noexcept; - - /** Called when a non-empty object starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_start(json_iterator &iter) noexcept; - /** - * Called when a key in a field is encountered. - * - * primitive, visit_object_start, visit_empty_object, visit_array_start, or - * visit_empty_array - * will be called after this with the field value. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_key(json_iterator &iter, const uint8_t *key) noexcept; - /** Called when a non-empty object ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_end(json_iterator &iter) noexcept; - /** Called when an empty object is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_object(json_iterator &iter) noexcept; - - /** - * Called when a string, number, boolean or null is found. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(json_iterator &iter, const uint8_t *value) noexcept; - /** - * Called when a string, number, boolean or null is found at the top level - * of a document (i.e. - * when there is no array or object and the entire document is a single - * string, number, boolean or - * null. - * - * This is separate from primitive() because simdjson's normal primitive - * parsing routines assume - * there is at least one more token after the value, which is only true in - * an array or object. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code visit_string( - json_iterator &iter, const uint8_t *value, bool key = false) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - visit_root_string(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - /** Called each time a new field or element in an array or object is found. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - increment_count(json_iterator &iter) noexcept; - - /** Next location to write to tape */ - tape_writer tape; - - private: - /** Next write location in the string buf for stage 2 parsing */ - uint8_t *current_string_buf_loc; - - simdjson_really_inline tape_builder(dom::document &doc) noexcept; - - simdjson_really_inline uint32_t next_tape_index(json_iterator &iter) const - noexcept; - simdjson_really_inline void start_container(json_iterator &iter) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_really_inline uint8_t *on_start_string( - json_iterator &iter) noexcept; - simdjson_really_inline void on_end_string(uint8_t *dst) noexcept; -}; // class tape_builder - -template -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept { - dom_parser.doc = &doc; - json_iterator iter(dom_parser, - STREAMING ? 
dom_parser.next_structural_index : 0); - tape_builder builder(doc); - return iter.walk_document(builder); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_root_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_object(json_iterator &iter) noexcept { - return empty_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_array(json_iterator &iter) noexcept { - return empty_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_document_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_end(json_iterator &iter) noexcept { - return end_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_end(json_iterator &iter) noexcept { - return end_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} -simdjson_warn_unused simdjson_really_inline error_code 
-tape_builder::visit_document_end(json_iterator &iter) noexcept { - constexpr uint32_t start_tape_index = 0; - tape.append(start_tape_index, internal::tape_type::ROOT); - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter), - internal::tape_type::ROOT); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_key(json_iterator &iter, const uint8_t *key) noexcept { - return visit_string(iter, key, true); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::increment_count(json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth] - .count++; // we have a key value pair in the object at - // parser.dom_parser.depth - 1 - return SUCCESS; -} - -simdjson_really_inline tape_builder::tape_builder(dom::document &doc) noexcept - : tape{doc.tape.get()}, - current_string_buf_loc{doc.string_buf.get()} {} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_string(json_iterator &iter, - const uint8_t *value, - bool key) noexcept { - iter.log_value(key ? "key" : "string"); - uint8_t *dst = on_start_string(iter); - dst = stringparsing::parse_string(value + 1, dst); - if (dst == nullptr) { - iter.log_error("Invalid escape in string"); - return STRING_ERROR; - } - on_end_string(dst); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_string(json_iterator &iter, - const uint8_t *value) noexcept { - return visit_string(iter, value); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_number(json_iterator &iter, const uint8_t *value) noexcept { - iter.log_value("number"); - return numberparsing::parse_number(value, tape); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_number(json_iterator &iter, - const uint8_t *value) noexcept { - // - // We need to make a copy to make sure that the string is space terminated. 
- // This is not about padding the input, which should already padded up - // to len + SIMDJSON_PADDING. However, we have no control at this stage - // on how the padding was done. What if the input string was padded with - // nulls? - // It is quite common for an input string to have an extra null character (C - // string). - // We do not want to allow 9\0 (where \0 is the null character) inside a - // JSON - // document, but the string "9\0" by itself is fine. So we make a copy and - // pad the input with spaces when we know that there is just one input - // element. - // This copy is relatively expensive, but it will almost never be called in - // practice unless you are in the strange scenario where you have many JSON - // documents made of single atoms. - // - std::unique_ptr copy( - new (std::nothrow) uint8_t[iter.remaining_len() + SIMDJSON_PADDING]); - if (copy.get() == nullptr) { - return MEMALLOC; - } - std::memcpy(copy.get(), value, iter.remaining_len()); - std::memset(copy.get() + iter.remaining_len(), ' ', SIMDJSON_PADDING); - error_code error = visit_number(iter, copy.get()); - return error; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value)) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value, iter.remaining_len())) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if 
(!atomparsing::is_valid_false_atom(value)) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if (!atomparsing::is_valid_false_atom(value, iter.remaining_len())) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value)) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value, iter.remaining_len())) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -// private: - -simdjson_really_inline uint32_t -tape_builder::next_tape_index(json_iterator &iter) const noexcept { - return uint32_t(tape.next_tape_loc - iter.dom_parser.doc->tape.get()); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - auto start_index = next_tape_index(iter); - tape.append(start_index + 2, start); - tape.append(start_index, end); - return SUCCESS; -} - -simdjson_really_inline void tape_builder::start_container( - json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth].tape_index = - next_tape_index(iter); - iter.dom_parser.open_containers[iter.depth].count = 0; - tape.skip(); // We don't actually *write* the start element until the end. 
-} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - // Write the ending tape element, pointing at the start location - const uint32_t start_tape_index = - iter.dom_parser.open_containers[iter.depth].tape_index; - tape.append(start_tape_index, end); - // Write the start tape element, pointing at the end location (and including - // count) - // count can overflow if it exceeds 24 bits... so we saturate - // the convention being that a cnt of 0xffffff or more is undetermined in - // value (>= 0xffffff). - const uint32_t count = iter.dom_parser.open_containers[iter.depth].count; - const uint32_t cntsat = count > 0xFFFFFF ? 0xFFFFFF : count; - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter) | (uint64_t(cntsat) << 32), - start); - return SUCCESS; -} - -simdjson_really_inline uint8_t *tape_builder::on_start_string( - json_iterator &iter) noexcept { - // we advance the point, accounting for the fact that we have a NULL - // termination - tape.append(current_string_buf_loc - iter.dom_parser.doc->string_buf.get(), - internal::tape_type::STRING); - return current_string_buf_loc + sizeof(uint32_t); -} - -simdjson_really_inline void tape_builder::on_end_string(uint8_t *dst) noexcept { - uint32_t str_length = - uint32_t(dst - (current_string_buf_loc + sizeof(uint32_t))); - // TODO check for overflow in case someone has a crazy string (>=4GB?) - // But only add the overflow check when the document itself exceeds 4GB - // Currently unneeded because we refuse to parse docs larger or equal to - // 4GB. - memcpy(current_string_buf_loc, &str_length, sizeof(uint32_t)); - // NULL termination is still handy if you expect all your strings to - // be NULL terminated? 
It comes at a small cost - *dst = 0; - current_string_buf_loc = dst + 1; -} - -} // namespace stage2 -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file src/generic/stage2/tape_builder.h */ - -// -// Implementation-specific overrides -// -namespace simdjson { -namespace ppc64 { -namespace { -namespace stage1 { - -simdjson_really_inline uint64_t -json_string_scanner::find_escaped(uint64_t backslash) { - // On PPC, we don't short-circuit this if there are no backslashes, because - // the branch gives us no - // benefit and therefore makes things worse. - // if (!backslash) { uint64_t escaped = prev_escaped; prev_escaped = 0; - // return escaped; } - return find_escaped_branchless(backslash); -} - -} // namespace stage1 -} // unnamed namespace - -simdjson_warn_unused error_code implementation::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept { - return ppc64::stage1::json_minifier::minify<64>(buf, len, dst, dst_len); -} - -simdjson_warn_unused error_code dom_parser_implementation::stage1( - const uint8_t *_buf, size_t _len, stage1_mode streaming) noexcept { - this->buf = _buf; - this->len = _len; - return ppc64::stage1::json_structural_indexer::index<64>( - buf, len, *this, streaming); -} - -simdjson_warn_unused bool implementation::validate_utf8(const char *buf, - size_t len) const - noexcept { - return ppc64::stage1::generic_validate_utf8(buf, len); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2_next(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code dom_parser_implementation::parse( - const uint8_t *_buf, size_t _len, dom::document &_doc) noexcept { - auto error = stage1(_buf, _len, stage1_mode::regular); - if (error) { - 
return error; - } - return stage2(_doc); -} - -} // namespace ppc64 -} // namespace simdjson - -/* begin file include/simdjson/ppc64/end.h */ -/* end file include/simdjson/ppc64/end.h */ -/* end file src/ppc64/dom_parser_implementation.cpp */ -#endif -#if SIMDJSON_IMPLEMENTATION_WESTMERE -/* begin file src/westmere/implementation.cpp */ -/* begin file include/simdjson/westmere/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "westmere" -// #define SIMDJSON_IMPLEMENTATION westmere -SIMDJSON_TARGET_WESTMERE -/* end file include/simdjson/westmere/begin.h */ - -namespace simdjson { -namespace westmere { - -simdjson_warn_unused error_code -implementation::create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const noexcept { - dst.reset(new (std::nothrow) dom_parser_implementation()); - if (!dst) { - return MEMALLOC; - } - if (auto err = dst->set_capacity(capacity)) return err; - if (auto err = dst->set_max_depth(max_depth)) return err; - return SUCCESS; -} - -} // namespace westmere -} // namespace simdjson - -/* begin file include/simdjson/westmere/end.h */ -SIMDJSON_UNTARGET_WESTMERE -/* end file include/simdjson/westmere/end.h */ -/* end file src/westmere/implementation.cpp */ -/* begin file src/westmere/dom_parser_implementation.cpp */ -/* begin file include/simdjson/westmere/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "westmere" -// #define SIMDJSON_IMPLEMENTATION westmere -SIMDJSON_TARGET_WESTMERE -/* end file include/simdjson/westmere/begin.h */ - -// -// Stage 1 -// - -namespace simdjson { -namespace westmere { -namespace { - -using namespace simd; - -struct json_character_block { - static simdjson_really_inline json_character_block - classify(const simd::simd8x64 &in); - - simdjson_really_inline uint64_t whitespace() const noexcept { - return _whitespace; - } - simdjson_really_inline uint64_t op() const noexcept { return _op; } - simdjson_really_inline uint64_t scalar() const noexcept { - return ~(op() | 
whitespace()); - } - - uint64_t _whitespace; - uint64_t _op; -}; - -simdjson_really_inline json_character_block -json_character_block::classify(const simd::simd8x64 &in) { - // These lookups rely on the fact that anything < 127 will match the lower 4 - // bits, which is why - // we can't use the generic lookup_16. - auto whitespace_table = simd8::repeat_16(' ', - 100, - 100, - 100, - 17, - 100, - 113, - 2, - 100, - '\t', - '\n', - 112, - 100, - '\r', - 100, - 100); - - // The 6 operators (:,[]{}) have these values: - // - // , 2C - // : 3A - // [ 5B - // { 7B - // ] 5D - // } 7D - // - // If you use | 0x20 to turn [ and ] into { and }, the lower 4 bits of each - // character is unique. - // We exploit this, using a simd 4-bit lookup to tell us which character - // match against, and then - // match it (against | 0x20). - // - // To prevent recognizing other characters, everything else gets compared - // with 0, which cannot - // match due to the | 0x20. - // - // NOTE: Due to the | 0x20, this ALSO treats and (control - // characters 0C and 1A) like , - // and :. This gets caught in stage 2, which checks the actual character to - // ensure the right - // operators are in the right places. - const auto op_table = - simd8::repeat_16(0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ':', - '{', // : = 3A, [ = 5B, { = 7B - ',', - '}', - 0, - 0 // , = 2C, ] = 5D, } = 7D - ); - - // We compute whitespace and op separately. If the code later only use one - // or the - // other, given the fact that all functions are aggressively inlined, we can - // hope that useless computations will be omitted. This is namely case when - // minifying (we only need whitespace). 
- - - const uint64_t whitespace = - in.eq({_mm_shuffle_epi8(whitespace_table, in.chunks[0]), - _mm_shuffle_epi8(whitespace_table, in.chunks[1]), - _mm_shuffle_epi8(whitespace_table, in.chunks[2]), - _mm_shuffle_epi8(whitespace_table, in.chunks[3])}); - // Turn [ and ] into { and } - const simd8x64 curlified{in.chunks[0] | 0x20, - in.chunks[1] | 0x20, - in.chunks[2] | 0x20, - in.chunks[3] | 0x20}; - const uint64_t op = - curlified.eq({_mm_shuffle_epi8(op_table, in.chunks[0]), - _mm_shuffle_epi8(op_table, in.chunks[1]), - _mm_shuffle_epi8(op_table, in.chunks[2]), - _mm_shuffle_epi8(op_table, in.chunks[3])}); - return {whitespace, op}; -} - -simdjson_really_inline bool is_ascii(const simd8x64 &input) { - return input.reduce_or().is_ascii(); -} - -simdjson_unused simdjson_really_inline simd8 must_be_continuation( - const simd8 prev1, - const simd8 prev2, - const simd8 prev3) { - simd8 is_second_byte = - prev1.saturating_sub(0b11000000u - 1); // Only 11______ will be > 0 - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. - return simd8(is_second_byte | is_third_byte | is_fourth_byte) > - int8_t(0); -} - -simdjson_really_inline simd8 must_be_2_3_continuation( - const simd8 prev2, const simd8 prev3) { - simd8 is_third_byte = - prev2.saturating_sub(0b11100000u - 1); // Only 111_____ will be > 0 - simd8 is_fourth_byte = - prev3.saturating_sub(0b11110000u - 1); // Only 1111____ will be > 0 - // Caller requires a bool (all 1's). All values resulting from the - // subtraction will be <= 64, so signed comparison is fine. 
- return simd8(is_third_byte | is_fourth_byte) > int8_t(0); -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson - -/* begin file src/generic/stage1/utf8_lookup4_algorithm.h */ -namespace simdjson { -namespace westmere { -namespace { -namespace utf8_validation { - -using namespace simd; - -simdjson_really_inline simd8 check_special_cases( - const simd8 input, const simd8 prev1) { - // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII) - // Bit 1 = Too Long (ASCII followed by continuation) - // Bit 2 = Overlong 3-byte - // Bit 4 = Surrogate - // Bit 5 = Overlong 2-byte - // Bit 7 = Two Continuations - constexpr const uint8_t TOO_SHORT = 1 << 0; // 11______ 0_______ - // 11______ 11______ - constexpr const uint8_t TOO_LONG = 1 << 1; // 0_______ 10______ - constexpr const uint8_t OVERLONG_3 = 1 << 2; // 11100000 100_____ - constexpr const uint8_t SURROGATE = 1 << 4; // 11101101 101_____ - constexpr const uint8_t OVERLONG_2 = 1 << 5; // 1100000_ 10______ - constexpr const uint8_t TWO_CONTS = 1 << 7; // 10______ 10______ - constexpr const uint8_t TOO_LARGE = 1 << 3; // 11110100 1001____ - // 11110100 101_____ - // 11110101 1001____ - // 11110101 101_____ - // 1111011_ 1001____ - // 1111011_ 101_____ - // 11111___ 1001____ - // 11111___ 101_____ - constexpr const uint8_t TOO_LARGE_1000 = 1 << 6; - // 11110101 1000____ - // 1111011_ 1000____ - // 11111___ 1000____ - constexpr const uint8_t OVERLONG_4 = 1 << 6; // 11110000 1000____ - - const simd8 byte_1_high = prev1.shr<4>().lookup_16( - // 0_______ ________ - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - TOO_LONG, - // 10______ ________ - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - TWO_CONTS, - // 1100____ ________ - TOO_SHORT | OVERLONG_2, - // 1101____ ________ - TOO_SHORT, - // 1110____ ________ - TOO_SHORT | OVERLONG_3 | SURROGATE, - // 1111____ ________ - TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4); - constexpr const uint8_t CARRY = - 
TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 . - const simd8 byte_1_low = - (prev1 & 0x0F) - .lookup_16( - // ____0000 ________ - CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, - // ____0001 ________ - CARRY | OVERLONG_2, - // ____001_ ________ - CARRY, - CARRY, - - // ____0100 ________ - CARRY | TOO_LARGE, - // ____0101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____011_ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - - // ____1___ ________ - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000, - // ____1101 ________ - CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, - CARRY | TOO_LARGE | TOO_LARGE_1000, - CARRY | TOO_LARGE | TOO_LARGE_1000); - const simd8 byte_2_high = input.shr<4>().lookup_16( - // ________ 0_______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - - // ________ 1000____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | - OVERLONG_4, - // ________ 1001____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, - // ________ 101_____ - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, - - // ________ 11______ - TOO_SHORT, - TOO_SHORT, - TOO_SHORT, - TOO_SHORT); - return (byte_1_high & byte_1_low & byte_2_high); -} -simdjson_really_inline simd8 check_multibyte_lengths( - const simd8 input, - const simd8 prev_input, - const simd8 sc) { - simd8 prev2 = input.prev<2>(prev_input); - simd8 prev3 = input.prev<3>(prev_input); - simd8 must23 = - simd8(must_be_2_3_continuation(prev2, prev3)); - simd8 must23_80 = must23 & uint8_t(0x80); - return must23_80 ^ sc; -} - -// -// Return nonzero if there are incomplete multibyte characters at the end of the -// block: -// e.g. 
if there is a 4-byte character, but it's 3 bytes from the end. -// -simdjson_really_inline simd8 is_incomplete( - const simd8 input) { - // If the previous input's last 3 bytes match this, they're too short (they - // ended at EOF): - // ... 1111____ 111_____ 11______ - static const uint8_t max_array[32] = {255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 0b11110000u - 1, - 0b11100000u - 1, - 0b11000000u - 1}; - const simd8 max_value( - &max_array[sizeof(max_array) - sizeof(simd8)]); - return input.gt_bits(max_value); -} - -struct utf8_checker { - // If this is nonzero, there has been a UTF-8 error. - simd8 error; - // The last input we received - simd8 prev_input_block; - // Whether the last input we received was incomplete (used for ASCII fast - // path) - simd8 prev_incomplete; - - // - // Check whether the current bytes are valid UTF-8. - // - simdjson_really_inline void check_utf8_bytes( - const simd8 input, const simd8 prev_input) { - // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or - // 4+ lead bytes - // (2, 3, 4-byte leads become large positive numbers instead of small - // negative numbers) - simd8 prev1 = input.prev<1>(prev_input); - simd8 sc = check_special_cases(input, prev1); - this->error |= check_multibyte_lengths(input, prev_input, sc); - } - - // The only problem that can happen at EOF is that a multibyte character is - // too short - // or a byte value too large in the last bytes: check_special_cases only - // checks for bytes - // too large in the first of two bytes. - simdjson_really_inline void check_eof() { - // If the previous block had incomplete UTF-8 characters at the end, an - // ASCII block can't - // possibly finish them. 
- this->error |= this->prev_incomplete; - } - - simdjson_really_inline void check_next_input( - const simd8x64 &input) { - if (simdjson_likely(is_ascii(input))) { - this->error |= this->prev_incomplete; - } else { - // you might think that a for-loop would work, but under Visual - // Studio, it is not good enough. - static_assert( - (simd8x64::NUM_CHUNKS == 2) || - (simd8x64::NUM_CHUNKS == 4), - "We support either two or four chunks per 64-byte block."); - if (simd8x64::NUM_CHUNKS == 2) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - } else if (simd8x64::NUM_CHUNKS == 4) { - this->check_utf8_bytes(input.chunks[0], this->prev_input_block); - this->check_utf8_bytes(input.chunks[1], input.chunks[0]); - this->check_utf8_bytes(input.chunks[2], input.chunks[1]); - this->check_utf8_bytes(input.chunks[3], input.chunks[2]); - } - this->prev_incomplete = - is_incomplete(input.chunks[simd8x64::NUM_CHUNKS - 1]); - this->prev_input_block = - input.chunks[simd8x64::NUM_CHUNKS - 1]; - } - } - // do not forget to call check_eof! - simdjson_really_inline error_code errors() { - return this->error.any_bits_set_anywhere() ? 
error_code::UTF8_ERROR - : error_code::SUCCESS; - } - -}; // struct utf8_checker -} // namespace utf8_validation - -using utf8_validation::utf8_checker; - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/utf8_lookup4_algorithm.h */ -/* begin file src/generic/stage1/json_structural_indexer.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -/* begin file src/generic/stage1/buf_block_reader.h */ -namespace simdjson { -namespace westmere { -namespace { - -// Walks through a buffer in block-sized increments, loading the last part with -// spaces -template -struct buf_block_reader { - public: - simdjson_really_inline buf_block_reader(const uint8_t *_buf, size_t _len); - simdjson_really_inline size_t block_index(); - simdjson_really_inline bool has_full_block() const; - simdjson_really_inline const uint8_t *full_block() const; - /** - * Get the last block, padded with spaces. - * - * There will always be a last block, with at least 1 byte, unless len == 0 - * (in which case this - * function fills the buffer with spaces and returns 0. In particular, if - * len == STEP_SIZE there - * will be 0 full_blocks and 1 remainder block with STEP_SIZE bytes and no - * spaces for padding. - * - * @return the number of effective characters in the last block. 
- */ - simdjson_really_inline size_t get_remainder(uint8_t *dst) const; - simdjson_really_inline void advance(); - - private: - const uint8_t *buf; - const size_t len; - const size_t lenminusstep; - size_t idx; -}; - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text_64(const uint8_t *text) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < sizeof(simd8x64); i++) { - buf[i] = int8_t(text[i]) < ' ' ? '_' : int8_t(text[i]); - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -// Routines to print masks and text for debugging bitmask operations -simdjson_unused static char *format_input_text(const simd8x64 &in) { - static char buf[sizeof(simd8x64) + 1]; - in.store(reinterpret_cast(buf)); - for (size_t i = 0; i < sizeof(simd8x64); i++) { - if (buf[i] < ' ') { - buf[i] = '_'; - } - } - buf[sizeof(simd8x64)] = '\0'; - return buf; -} - -simdjson_unused static char *format_mask(uint64_t mask) { - static char buf[sizeof(simd8x64) + 1]; - for (size_t i = 0; i < 64; i++) { - buf[i] = (mask & (size_t(1) << i)) ? 'X' : ' '; - } - buf[64] = '\0'; - return buf; -} - -template -simdjson_really_inline buf_block_reader::buf_block_reader( - const uint8_t *_buf, size_t _len) - : buf{_buf}, - len{_len}, - lenminusstep{len < STEP_SIZE ? 
0 : len - STEP_SIZE}, - idx{0} {} - -template -simdjson_really_inline size_t buf_block_reader::block_index() { - return idx; -} - -template -simdjson_really_inline bool buf_block_reader::has_full_block() - const { - return idx < lenminusstep; -} - -template -simdjson_really_inline const uint8_t *buf_block_reader::full_block() - const { - return &buf[idx]; -} - -template -simdjson_really_inline size_t -buf_block_reader::get_remainder(uint8_t *dst) const { - if (len == idx) { - return 0; - } // memcpy(dst, null, 0) will trigger an error with some sanitizers - std::memset(dst, 0x20, STEP_SIZE); // std::memset STEP_SIZE because it's - // more efficient to write out 8 or 16 - // bytes at once. - std::memcpy(dst, buf + idx, len - idx); - return len - idx; -} - -template -simdjson_really_inline void buf_block_reader::advance() { - idx += STEP_SIZE; -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/buf_block_reader.h */ -/* begin file src/generic/stage1/json_string_scanner.h */ -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -struct json_string_block { - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_string_block(uint64_t backslash, - uint64_t escaped, - uint64_t quote, - uint64_t in_string) - : _backslash(backslash), - _escaped(escaped), - _quote(quote), - _in_string(in_string) {} - - // Escaped characters (characters following an escape() character) - simdjson_really_inline uint64_t escaped() const { return _escaped; } - // Escape characters (backslashes that are not escaped--i.e. 
in \\, includes - // only the first \) - simdjson_really_inline uint64_t escape() const { - return _backslash & ~_escaped; - } - // Real (non-backslashed) quotes - simdjson_really_inline uint64_t quote() const { return _quote; } - // Start quotes of strings - simdjson_really_inline uint64_t string_start() const { - return _quote & _in_string; - } - // End quotes of strings - simdjson_really_inline uint64_t string_end() const { - return _quote & ~_in_string; - } - // Only characters inside the string (not including the quotes) - simdjson_really_inline uint64_t string_content() const { - return _in_string & ~_quote; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_inside_string(uint64_t mask) const { - return mask & _in_string; - } - // Return a mask of whether the given characters are inside a string (only - // works on non-quotes) - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const { - return mask & ~_in_string; - } - // Tail of string (everything except the start quote) - simdjson_really_inline uint64_t string_tail() const { - return _in_string ^ _quote; - } - - // backslash characters - uint64_t _backslash; - // escaped characters (backslashed--does not include the hex characters - // after \u) - uint64_t _escaped; - // real quotes (non-backslashed ones) - uint64_t _quote; - // string characters (includes start quote but not end quote) - uint64_t _in_string; -}; - -// Scans blocks for string characters, storing the state necessary to do so -class json_string_scanner { - public: - simdjson_really_inline json_string_block - next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Intended to be defined by the implementation - simdjson_really_inline uint64_t find_escaped(uint64_t escape); - simdjson_really_inline uint64_t 
find_escaped_branchless(uint64_t escape); - - // Whether the last iteration was still inside a string (all 1's = true, all - // 0's = false). - uint64_t prev_in_string = 0ULL; - // Whether the first character of the next iteration is escaped. - uint64_t prev_escaped = 0ULL; -}; - -// -// Finds escaped characters (characters following \). -// -// Handles runs of backslashes like \\\" and \\\\" correctly (yielding 0101 and -// 01010, respectively). -// -// Does this by: -// - Shift the escape mask to get potentially escaped characters (characters -// after backslashes). -// - Mask escaped sequences that start on *even* bits with 1010101010 (odd bits -// are escaped, even bits are not) -// - Mask escaped sequences that start on *odd* bits with 0101010101 (even bits -// are escaped, odd bits are not) -// -// To distinguish between escaped sequences starting on even/odd bits, it finds -// the start of all -// escape sequences, filters out the ones that start on even bits, and adds that -// to the mask of -// escape sequences. This causes the addition to clear out the sequences -// starting on odd bits (since -// the start bit causes a carry), and leaves even-bit sequences alone. 
-// -// Example: -// -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// escape | xxx | xx xxx xxx xx xx | Removed overflow backslash; -// will | it into follows_escape -// odd_starts | x | x x x | escape & ~even_bits & -// ~follows_escape -// even_seq | c| cxxx c xx c | c = carry bit -- will be -// masked out later -// invert_mask | | cxxx c xx c| even_seq << 1 -// follows_escape | xx | x xx xxx xxx xx xx | Includes overflow bit -// escaped | x | x x x x x x x x | -// desired | x | x x x x x x x x | -// text | \\\ | \\\"\\\" \\\" \\"\\" | -// -simdjson_really_inline uint64_t -json_string_scanner::find_escaped_branchless(uint64_t backslash) { - // If there was overflow, pretend the first character isn't a backslash - backslash &= ~prev_escaped; - uint64_t follows_escape = backslash << 1 | prev_escaped; - - // Get sequences starting on even bits by clearing out the odd series using - // + - const uint64_t even_bits = 0x5555555555555555ULL; - uint64_t odd_sequence_starts = backslash & ~even_bits & ~follows_escape; - uint64_t sequences_starting_on_even_bits; - prev_escaped = add_overflow( - odd_sequence_starts, backslash, &sequences_starting_on_even_bits); - uint64_t invert_mask = - sequences_starting_on_even_bits - << 1; // The mask we want to return is the *escaped* bits, not escapes. - - // Mask every other backslashed character as an escaped character - // Flip the mask for sequences that start on even bits, to correct them - return (even_bits ^ invert_mask) & follows_escape; -} - -// -// Return a mask of all string characters plus end quotes. -// -// prev_escaped is overflow saying whether the next character is escaped. -// prev_in_string is overflow saying whether we're still in a string. -// -// Backslash sequences outside of quotes will be detected in stage 2. 
-// -simdjson_really_inline json_string_block -json_string_scanner::next(const simd::simd8x64 &in) { - const uint64_t backslash = in.eq('\\'); - const uint64_t escaped = find_escaped(backslash); - const uint64_t quote = in.eq('"') & ~escaped; - - // - // prefix_xor flips on bits inside the string (and flips off the end quote). - // - // Then we xor with prev_in_string: if we were in a string already, its - // effect is flipped - // (characters inside strings are outside, and characters outside strings - // are inside). - // - const uint64_t in_string = prefix_xor(quote) ^ prev_in_string; - - // - // Check if we're still in a string at the end of the box so the next block - // will know - // - // right shift of a signed value expected to be well-defined and standard - // compliant as of C++20, John Regher from Utah U. says this is fine code - // - prev_in_string = uint64_t(static_cast(in_string) >> 63); - - // Use ^ to turn the beginning quote off, and the end quote on. - - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_string_block(backslash, escaped, quote, in_string); -} - -simdjson_really_inline error_code json_string_scanner::finish() { - if (prev_in_string) { - return UNCLOSED_STRING; - } - return SUCCESS; -} - -} // namespace stage1 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/json_string_scanner.h */ -/* begin file src/generic/stage1/json_scanner.h */ -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -/** - * A block of scanned json, with information on operators and scalars. - * - * We seek to identify pseudo-structural characters. Anything that is inside - * a string must be omitted (hence & ~_string.string_tail()). - * Otherwise, pseudo-structural characters come in two forms. - * 1. We have the structural characters ([,],{,},:, comma). 
The - * term 'structural character' is from the JSON RFC. - * 2. We have the 'scalar pseudo-structural characters'. - * Scalars are quotes, and any character except structural characters and - * white space. - * - * To identify the scalar pseudo-structural characters, we must look at what - * comes - * before them: it must be a space, a quote or a structural characters. - * Starting with simdjson v0.3, we identify them by - * negation: we identify everything that is followed by a non-quote scalar, - * and we negate that. Whatever remains must be a 'scalar pseudo-structural - * character'. - */ -struct json_block { - public: - // We spell out the constructors in the hope of resolving inlining issues - // with Visual Studio 2017 - simdjson_really_inline json_block( - json_string_block &&string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(std::move(string)), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - simdjson_really_inline json_block( - json_string_block string, - json_character_block characters, - uint64_t follows_potential_nonquote_scalar) - : _string(string), - _characters(characters), - _follows_potential_nonquote_scalar( - follows_potential_nonquote_scalar) {} - - /** - * The start of structurals. - * In simdjson prior to v0.3, these were called the pseudo-structural - *characters. - **/ - simdjson_really_inline uint64_t structural_start() const noexcept { - return potential_structural_start() & ~_string.string_tail(); - } - /** All JSON whitespace (i.e. 
not in a string) */ - simdjson_really_inline uint64_t whitespace() const noexcept { - return non_quote_outside_string(_characters.whitespace()); - } - - // Helpers - - /** Whether the given characters are inside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t non_quote_inside_string(uint64_t mask) const - noexcept { - return _string.non_quote_inside_string(mask); - } - /** Whether the given characters are outside a string (only works on - * non-quotes) */ - simdjson_really_inline uint64_t - non_quote_outside_string(uint64_t mask) const noexcept { - return _string.non_quote_outside_string(mask); - } - - // string and escape characters - json_string_block _string; - // whitespace, structural characters ('operators'), scalars - json_character_block _characters; - // whether the previous character was a scalar - uint64_t _follows_potential_nonquote_scalar; - - private: - // Potential structurals (i.e. disregarding strings) - - /** - * structural elements ([,],{,},:, comma) plus scalar starts like 123, true - *and "abc". - * They may reside inside a string. - **/ - simdjson_really_inline uint64_t potential_structural_start() const - noexcept { - return _characters.op() | potential_scalar_start(); - } - /** - * The start of non-operator runs, like 123, true and "abc". - * It main reside inside a string. - **/ - simdjson_really_inline uint64_t potential_scalar_start() const noexcept { - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // Whenever it is preceded by something that is not a structural element - // ({,},[,],:, ") nor a white-space - // then we know that it is irrelevant structurally. - return _characters.scalar() & ~follows_potential_scalar(); - } - /** - * Whether the given character is immediately after a non-operator like 123, - * true. - * The characters following a quote are not included. 
- */ - simdjson_really_inline uint64_t follows_potential_scalar() const noexcept { - // _follows_potential_nonquote_scalar: is defined as marking any - // character that follows a character - // that is not a structural element ({,},[,],:, comma) nor a quote (") - // and that is not a - // white space. - // It is understood that within quoted region, anything at all could be - // marked (irrelevant). - return _follows_potential_nonquote_scalar; - } -}; - -/** - * Scans JSON for important bits: structural characters or 'operators', strings, - * and scalars. - * - * The scanner starts by calculating two distinct things: - * - string characters (taking \" into account) - * - structural characters or 'operators' ([]{},:, comma) - * and scalars (runs of non-operators like 123, true and "abc") - * - * To minimize data dependency (a key component of the scanner's speed), it - * finds these in parallel: - * in particular, the operator/scalar bit will find plenty of things that are - * actually part of - * strings. When we're done, json_block will fuse the two together by masking - * out tokens that are - * part of a string. - */ -class json_scanner { - public: - json_scanner() {} - simdjson_really_inline json_block next(const simd::simd8x64 &in); - // Returns either UNCLOSED_STRING or SUCCESS - simdjson_really_inline error_code finish(); - - private: - // Whether the last character of the previous iteration is part of a scalar - // token - // (anything except whitespace or a structural character/'operator'). - uint64_t prev_scalar = 0ULL; - json_string_scanner string_scanner{}; -}; - - -// -// Check if the current character immediately follows a matching character. 
-// -// For example, this checks for quotes with backslashes in front of them: -// -// const uint64_t backslashed_quote = in.eq('"') & -// immediately_follows(in.eq('\'), prev_backslash); -// -simdjson_really_inline uint64_t follows(const uint64_t match, - uint64_t &overflow) { - const uint64_t result = match << 1 | overflow; - overflow = match >> 63; - return result; -} - -simdjson_really_inline json_block -json_scanner::next(const simd::simd8x64 &in) { - json_string_block strings = string_scanner.next(in); - // identifies the white-space and the structural characters - json_character_block characters = json_character_block::classify(in); - // The term "scalar" refers to anything except structural characters and - // white space - // (so letters, numbers, quotes). - // We want follows_scalar to mark anything that follows a non-quote scalar - // (so letters and numbers). - // - // A terminal quote should either be followed by a structural character - // (comma, brace, bracket, colon) - // or nothing. However, we still want ' "a string"true ' to mark the 't' of - // 'true' as a potential - // pseudo-structural character just like we would if we had ' "a string" - // true '; otherwise we - // may need to add an extra check when parsing strings. - // - // Performance: there are many ways to skin this cat. - const uint64_t nonquote_scalar = characters.scalar() & ~strings.quote(); - uint64_t follows_nonquote_scalar = follows(nonquote_scalar, prev_scalar); - // We are returning a function-local object so either we get a move - // constructor - // or we get copy elision. - return json_block(strings, // strings is a function-local object so either - // it moves or the copy is elided. 
- characters, - follows_nonquote_scalar); -} - -simdjson_really_inline error_code json_scanner::finish() { - return string_scanner.finish(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/json_scanner.h */ -/* begin file src/generic/stage1/json_minifier.h */ -// This file contains the common code every implementation uses in stage1 -// It is intended to be included multiple times and compiled multiple times -// We assume the file in which it is included already includes -// "simdjson/stage1.h" (this simplifies amalgation) - -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -class json_minifier { - public: - template - static error_code minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept; - - private: - simdjson_really_inline json_minifier(uint8_t *_dst) : dst{_dst} {} - template - simdjson_really_inline void step( - const uint8_t *block_buf, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block); - simdjson_really_inline error_code finish(uint8_t *dst_start, - size_t &dst_len); - json_scanner scanner{}; - uint8_t *dst; -}; - -simdjson_really_inline void json_minifier::next( - const simd::simd8x64 &in, const json_block &block) { - uint64_t mask = block.whitespace(); - dst += in.compress(mask, dst); -} - -simdjson_really_inline error_code json_minifier::finish(uint8_t *dst_start, - size_t &dst_len) { - error_code error = scanner.finish(); - if (error) { - dst_len = 0; - return error; - } - dst_len = dst - dst_start; - return SUCCESS; -} - -template <> -simdjson_really_inline void json_minifier::step<128>( - const uint8_t *block_buf, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - simd::simd8x64 in_2(block_buf + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, 
block_1); - this->next(in_2, block_2); - reader.advance(); -} - -template <> -simdjson_really_inline void json_minifier::step<64>( - const uint8_t *block_buf, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block_buf); - json_block block_1 = scanner.next(in_1); - this->next(block_buf, block_1); - reader.advance(); -} - -template -error_code json_minifier::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) noexcept { - buf_block_reader reader(buf, len); - json_minifier minifier(dst); - - // Index the first n-1 blocks - while (reader.has_full_block()) { - minifier.step(reader.full_block(), reader); - } - - // Index the last (remainder) block, padded with spaces - uint8_t block[STEP_SIZE]; - size_t remaining_bytes = reader.get_remainder(block); - if (remaining_bytes > 0) { - // We do not want to write directly to the output stream. Rather, we - // write - // to a local buffer (for safety). - uint8_t out_block[STEP_SIZE]; - uint8_t *const guarded_dst{minifier.dst}; - minifier.dst = out_block; - minifier.step(block, reader); - size_t to_write = minifier.dst - out_block; - // In some cases, we could be enticed to consider the padded spaces - // as part of the string. This is fine as long as we do not write more - // than we consumed. - if (to_write > remaining_bytes) { - to_write = remaining_bytes; - } - memcpy(guarded_dst, out_block, to_write); - minifier.dst = guarded_dst + to_write; - } - return minifier.finish(dst, dst_len); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/json_minifier.h */ -/* begin file src/generic/stage1/find_next_document_index.h */ -namespace simdjson { -namespace westmere { -namespace { - -/** - * This algorithm is used to quickly identify the last structural position that - * makes up a complete document. 
- * - * It does this by going backwards and finding the last *document boundary* (a - * place where one value follows another without a comma between them). If the - * last document (the characters after the boundary) has an equal number of - * start and end brackets, it is considered complete. - * - * Simply put, we iterate over the structural characters, starting from - * the end. We consider that we found the end of a JSON document when the - * first element of the pair is NOT one of these characters: '{' '[' ':' ',' - * and when the second element is NOT one of these characters: '}' ']' ':' ','. - * - * This simple comparison works most of the time, but it does not cover cases - * where the batch's structural indexes contain a perfect amount of documents. - * In such a case, we do not have access to the structural index which follows - * the last document, therefore, we do not have access to the second element in - * the pair, and that means we cannot identify the last document. To fix this - * issue, we keep a count of the open and closed curly/square braces we found - * while searching for the pair. When we find a pair AND the count of open and - * closed curly/square braces is the same, we know that we just passed a - * complete document, therefore the last json buffer location is the end of the - * batch. 
- */ -simdjson_really_inline uint32_t -find_next_document_index(dom_parser_implementation &parser) { - // Variant: do not count separately, just figure out depth - if (parser.n_structural_indexes == 0) { - return 0; - } - auto arr_cnt = 0; - auto obj_cnt = 0; - for (auto i = parser.n_structural_indexes - 1; i > 0; i--) { - auto idxb = parser.structural_indexes[i]; - switch (parser.buf[idxb]) { - case ':': - case ',': - continue; - case '}': - obj_cnt--; - continue; - case ']': - arr_cnt--; - continue; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - auto idxa = parser.structural_indexes[i - 1]; - switch (parser.buf[idxa]) { - case '{': - case '[': - case ':': - case ',': - continue; - } - // Last document is complete, so the next document will appear after! - if (!arr_cnt && !obj_cnt) { - return parser.n_structural_indexes; - } - // Last document is incomplete; mark the document at i + 1 as the next - // one - return i; - } - // If we made it to the end, we want to finish counting to see if we have a - // full document. - switch (parser.buf[parser.structural_indexes[0]]) { - case '}': - obj_cnt--; - break; - case ']': - arr_cnt--; - break; - case '{': - obj_cnt++; - break; - case '[': - arr_cnt++; - break; - } - if (!arr_cnt && !obj_cnt) { - // We have a complete document. 
- return parser.n_structural_indexes; - } - return 0; -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/find_next_document_index.h */ - -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -class bit_indexer { - public: - uint32_t *tail; - - simdjson_really_inline bit_indexer(uint32_t *index_buf) : tail(index_buf) {} - - // flatten out values in 'bits' assuming that they are are to have values of - // idx - // plus their position in the bitvector, and store these indexes at - // base_ptr[base] incrementing base as we go - // will potentially store extra values beyond end of valid bits, so base_ptr - // needs to be large enough to handle this - simdjson_really_inline void write(uint32_t idx, uint64_t bits) { - // In some instances, the next branch is expensive because it is - // mispredicted. - // Unfortunately, in other cases, - // it helps tremendously. - if (bits == 0) return; -#if defined(SIMDJSON_PREFER_REVERSE_BITS) - /** - * ARM lacks a fast trailing zero instruction, but it has a fast - * bit reversal instruction and a fast leading zero instruction. - * Thus it may be profitable to reverse the bits (once) and then - * to rely on a sequence of instructions that call the leading - * zero instruction. - * - * Performance notes: - * The chosen routine is not optimal in terms of data dependency - * since zero_leading_bit might require two instructions. However, - * it tends to minimize the total number of instructions which is - * beneficial. - */ - - uint64_t rev_bits = reverse_bits(bits); - int cnt = static_cast(count_ones(bits)); - int i = 0; - // Do the first 8 all together - for (; i < 8; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). 
- if (simdjson_unlikely(cnt > 8)) { - i = 8; - for (; i < 16; i++) { - int lz = leading_zeroes(rev_bits); - this->tail[i] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. - if (simdjson_unlikely(cnt > 16)) { - i = 16; - while (rev_bits != 0) { - int lz = leading_zeroes(rev_bits); - this->tail[i++] = static_cast(idx) + lz; - rev_bits = zero_leading_bit(rev_bits, lz); - } - } - } - this->tail += cnt; -#else // SIMDJSON_PREFER_REVERSE_BITS - /** - * Under recent x64 systems, we often have both a fast trailing zero - * instruction and a fast 'clear-lower-bit' instruction so the following - * algorithm can be competitive. - */ - - int cnt = static_cast(count_ones(bits)); - // Do the first 8 all together - for (int i = 0; i < 8; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Do the next 8 all together (we hope in most cases it won't happen at - // all - // and the branch is easily predicted). - if (simdjson_unlikely(cnt > 8)) { - for (int i = 8; i < 16; i++) { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - } - - // Most files don't have 16+ structurals per block, so we take - // several basically guaranteed - // branch mispredictions here. 16+ structurals per block means - // either punctuation ({} [] , :) - // or the start of a value ("abc" true 123) every four characters. 
- if (simdjson_unlikely(cnt > 16)) { - int i = 16; - do { - this->tail[i] = idx + trailing_zeroes(bits); - bits = clear_lowest_bit(bits); - i++; - } while (i < cnt); - } - } - - this->tail += cnt; -#endif - } -}; - -class json_structural_indexer { - public: - /** - * Find the important bits of JSON in a 128-byte chunk, and add them to - * structural_indexes. - * - * @param partial Setting the partial parameter to true allows the - * find_structural_bits to - * tolerate unclosed strings. The caller should still ensure that the - * input is valid UTF-8. If - * you are processing substrings, you may want to call on a function like - * trimmed_length_safe_utf8. - */ - template - static error_code index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept; - - private: - simdjson_really_inline json_structural_indexer( - uint32_t *structural_indexes); - template - simdjson_really_inline void step( - const uint8_t *block, buf_block_reader &reader) noexcept; - simdjson_really_inline void next(const simd::simd8x64 &in, - const json_block &block, - size_t idx); - simdjson_really_inline error_code finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial); - - json_scanner scanner{}; - utf8_checker checker{}; - bit_indexer indexer; - uint64_t prev_structurals = 0; - uint64_t unescaped_chars_error = 0; -}; - -simdjson_really_inline json_structural_indexer::json_structural_indexer( - uint32_t *structural_indexes) - : indexer{structural_indexes} {} - -// Skip the last character if it is partial -simdjson_really_inline size_t trim_partial_utf8(const uint8_t *buf, - size_t len) { - if (simdjson_unlikely(len < 3)) { - switch (len) { - case 2: - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 2 bytes left - return len; - case 1: - if (buf[len - 
1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - return len; - case 0: - return len; - } - } - if (buf[len - 1] >= 0b11000000) { - return len - 1; - } // 2-, 3- and 4-byte characters with only 1 byte left - if (buf[len - 2] >= 0b11100000) { - return len - 2; - } // 3- and 4-byte characters with only 1 byte left - if (buf[len - 3] >= 0b11110000) { - return len - 3; - } // 4-byte characters with only 3 bytes left - return len; -} - -// -// PERF NOTES: -// We pipe 2 inputs through these stages: -// 1. Load JSON into registers. This takes a long time and is highly -// parallelizable, so we load -// 2 inputs' worth at once so that by the time step 2 is looking for them -// input, it's available. -// 2. Scan the JSON for critical data: strings, scalars and operators. This is -// the critical path. -// The output of step 1 depends entirely on this information. These functions -// don't quite use -// up enough CPU: the second half of the functions is highly serial, only -// using 1 execution core -// at a time. The second input's scans has some dependency on the first ones -// finishing it, but -// they can make a lot of progress before they need that information. -// 3. Step 1 doesn't use enough capacity, so we run some extra stuff while we're -// waiting for that -// to finish: utf-8 checks and generating the output from the last iteration. -// -// The reason we run 2 inputs at a time, is steps 2 and 3 are *still* not enough -// to soak up all -// available capacity with just one input. Running 2 at a time seems to give the -// CPU a good enough -// workout. -// -template -error_code json_structural_indexer::index(const uint8_t *buf, - size_t len, - dom_parser_implementation &parser, - stage1_mode partial) noexcept { - if (simdjson_unlikely(len > parser.capacity())) { - return CAPACITY; - } - // We guard the rest of the code so that we can assume that len > 0 - // throughout. 
- if (len == 0) { - return EMPTY; - } - if (is_streaming(partial)) { - len = trim_partial_utf8(buf, len); - // If you end up with an empty window after trimming - // the partial UTF-8 bytes, then chances are good that you - // have an UTF-8 formatting error. - if (len == 0) { - return UTF8_ERROR; - } - } - buf_block_reader reader(buf, len); - json_structural_indexer indexer(parser.structural_indexes.get()); - - // Read all but the last block - while (reader.has_full_block()) { - indexer.step(reader.full_block(), reader); - } - // Take care of the last block (will always be there unless file is empty - // which is - // not supposed to happen.) - uint8_t block[STEP_SIZE]; - if (simdjson_unlikely(reader.get_remainder(block) == 0)) { - return UNEXPECTED_ERROR; - } - indexer.step(block, reader); - return indexer.finish(parser, reader.block_index(), len, partial); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<128>( - const uint8_t *block, buf_block_reader<128> &reader) noexcept { - simd::simd8x64 in_1(block); - simd::simd8x64 in_2(block + 64); - json_block block_1 = scanner.next(in_1); - json_block block_2 = scanner.next(in_2); - this->next(in_1, block_1, reader.block_index()); - this->next(in_2, block_2, reader.block_index() + 64); - reader.advance(); -} - -template <> -simdjson_really_inline void json_structural_indexer::step<64>( - const uint8_t *block, buf_block_reader<64> &reader) noexcept { - simd::simd8x64 in_1(block); - json_block block_1 = scanner.next(in_1); - this->next(in_1, block_1, reader.block_index()); - reader.advance(); -} - -simdjson_really_inline void json_structural_indexer::next( - const simd::simd8x64 &in, const json_block &block, size_t idx) { - uint64_t unescaped = in.lteq(0x1F); - checker.check_next_input(in); - indexer.write(uint32_t(idx - 64), prev_structurals); // Output *last* - // iteration's - // structurals to the - // parser - prev_structurals = block.structural_start(); - unescaped_chars_error |= 
block.non_quote_inside_string(unescaped); -} - -simdjson_really_inline error_code -json_structural_indexer::finish(dom_parser_implementation &parser, - size_t idx, - size_t len, - stage1_mode partial) { - // Write out the final iteration's structurals - indexer.write(uint32_t(idx - 64), prev_structurals); - error_code error = scanner.finish(); - // We deliberately break down the next expression so that it is - // human readable. - const bool should_we_exit = - is_streaming(partial) - ? ((error != SUCCESS) && - (error != - UNCLOSED_STRING)) // when partial we tolerate UNCLOSED_STRING - : (error != SUCCESS); // if partial is false, we must have SUCCESS - const bool have_unclosed_string = (error == UNCLOSED_STRING); - if (simdjson_unlikely(should_we_exit)) { - return error; - } - - if (unescaped_chars_error) { - return UNESCAPED_CHARS; - } - parser.n_structural_indexes = - uint32_t(indexer.tail - parser.structural_indexes.get()); - /*** - * The On Demand API requires special padding. - * - * This is related to https://github.com/simdjson/simdjson/issues/906 - * Basically, we want to make sure that if the parsing continues beyond the - *last (valid) - * structural character, it quickly stops. - * Only three structural characters can be repeated without triggering an - *error in JSON: [,] and }. - * We repeat the padding character (at 'len'). We don't know what it is, but - *if the parsing - * continues, then it must be [,] or }. - * Suppose it is ] or }. We backtrack to the first character, what could it - *be that would - * not trigger an error? It could be ] or } but no, because you can't start - *a document that way. - * It can't be a comma, a colon or any simple value. So the only way we - *could continue is - * if the repeated character is [. But if so, the document must start with - *[. But if the document - * starts with [, it should end with ]. If we enforce that rule, then we - *would get - * ][[ which is invalid. 
- * - * This is illustrated with the test array_iterate_unclosed_error() on the - *following input: - * R"({ "a": [,,)" - **/ - parser.structural_indexes[parser.n_structural_indexes] = - uint32_t(len); // used later in partial == stage1_mode::streaming_final - parser.structural_indexes[parser.n_structural_indexes + 1] = uint32_t(len); - parser.structural_indexes[parser.n_structural_indexes + 2] = 0; - parser.next_structural_index = 0; - // a valid JSON file cannot have zero structural indexes - we should have - // found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return EMPTY; - } - if (simdjson_unlikely( - parser.structural_indexes[parser.n_structural_indexes - 1] > len)) { - return UNEXPECTED_ERROR; - } - if (partial == stage1_mode::streaming_partial) { - // If we have an unclosed string, then the last structural - // will be the quote and we want to make sure to omit it. - if (have_unclosed_string) { - parser.n_structural_indexes--; - // a valid JSON file cannot have zero structural indexes - we should - // have found something - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - return CAPACITY; - } - } - // We truncate the input to the end of the last complete document (or - // zero). - auto new_structural_indexes = find_next_document_index(parser); - if (new_structural_indexes == 0 && parser.n_structural_indexes > 0) { - if (parser.structural_indexes[0] == 0) { - // If the buffer is partial and we started at index 0 but the - // document is - // incomplete, it's too big to parse. - return CAPACITY; - } else { - // It is possible that the document could be parsed, we just had - // a lot - // of white space. 
- parser.n_structural_indexes = 0; - return EMPTY; - } - } - - parser.n_structural_indexes = new_structural_indexes; - } else if (partial == stage1_mode::streaming_final) { - if (have_unclosed_string) { - parser.n_structural_indexes--; - } - // We truncate the input to the end of the last complete document (or - // zero). - // Because partial == stage1_mode::streaming_final, it means that we may - // silently ignore trailing garbage. Though it sounds bad, we do it - // deliberately because many people who have streams of JSON documents - // will truncate them for processing. E.g., imagine that you are - // uncompressing - // the data from a size file or receiving it in chunks from the network. - // You - // may not know where exactly the last document will be. Meanwhile the - // document_stream instances allow people to know the JSON documents - // they are - // parsing (see the iterator.source() method). - parser.n_structural_indexes = find_next_document_index(parser); - // We store the initial n_structural_indexes so that the client can see - // whether we used truncation. If initial_n_structural_indexes == - // parser.n_structural_indexes, - // then this will query - // parser.structural_indexes[parser.n_structural_indexes] which is len, - // otherwise, it will copy some prior index. - parser.structural_indexes[parser.n_structural_indexes + 1] = - parser.structural_indexes[parser.n_structural_indexes]; - // This next line is critical, do not change it unless you understand - // what you are - // doing. - parser.structural_indexes[parser.n_structural_indexes] = uint32_t(len); - if (simdjson_unlikely(parser.n_structural_indexes == 0u)) { - // We tolerate an unclosed string at the very end of the stream. - // Indeed, users - // often load their data in bulk without being careful and they want - // us to ignore - // the trailing garbage. 
- return EMPTY; - } - } - checker.check_eof(); - return checker.errors(); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/json_structural_indexer.h */ -/* begin file src/generic/stage1/utf8_validator.h */ -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -/** - * Validates that the string is actual UTF-8. - */ -template -bool generic_validate_utf8(const uint8_t *input, size_t length) { - checker c{}; - buf_block_reader<64> reader(input, length); - while (reader.has_full_block()) { - simd::simd8x64 in(reader.full_block()); - c.check_next_input(in); - reader.advance(); - } - uint8_t block[64]{}; - reader.get_remainder(block); - simd::simd8x64 in(block); - c.check_next_input(in); - reader.advance(); - c.check_eof(); - return c.errors() == error_code::SUCCESS; -} - -bool generic_validate_utf8(const char *input, size_t length) { - return generic_validate_utf8( - reinterpret_cast(input), length); -} - -} // namespace stage1 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage1/utf8_validator.h */ - -// -// Stage 2 -// -/* begin file src/generic/stage2/tape_builder.h */ -/* begin file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/logger.h */ -// This is for an internal-only stage 2 specific logger. -// Set LOG_ENABLED = true to log what stage 2 is doing! 
-namespace simdjson { -namespace westmere { -namespace { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static constexpr const int LOG_INDEX_LEN = 5; - -static int log_depth; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static simdjson_really_inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -// Print the header and set up log_start -static simdjson_really_inline void log_start() { - if (LOG_ENABLED) { - log_depth = 0; - printf("\n"); - printf("| %-*s | %-*s | %-*s | %-*s | Detail |\n", - LOG_EVENT_LEN, - "Event", - LOG_BUFFER_LEN, - "Buffer", - LOG_SMALL_BUFFER_LEN, - "Next", - 5, - "Next#"); - printf("|%.*s|%.*s|%.*s|%.*s|--------|\n", - LOG_EVENT_LEN + 2, - DASHES, - LOG_BUFFER_LEN + 2, - DASHES, - LOG_SMALL_BUFFER_LEN + 2, - DASHES, - 5 + 2, - DASHES); - } -} - -simdjson_unused static simdjson_really_inline void log_string( - const char *message) { - if (LOG_ENABLED) { - printf("%s\n", message); - } -} - -// Logs a single line from the stage 2 DOM parser -template -static simdjson_really_inline void log_line(S &structurals, - const char *title_prefix, - const char *title, - const char *detail) { - if (LOG_ENABLED) { - printf("| %*s%s%-*s ", - log_depth * 2, - "", - title_prefix, - LOG_EVENT_LEN - log_depth * 2 - int(strlen(title_prefix)), - title); - auto current_index = 
structurals.at_beginning() - ? nullptr - : structurals.next_structural - 1; - auto next_index = structurals.next_structural; - auto current = current_index ? &structurals.buf[*current_index] - : reinterpret_cast( - " " - " "); - auto next = &structurals.buf[*next_index]; - { - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current[i])); - } - printf(" "); - // Print the next N characters in the buffer. - printf("| "); - // Otherwise, print the characters starting from the buffer - // position. - // Print spaces for unprintable or newline characters. - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next[i])); - } - printf(" "); - } - if (current_index) { - printf("| %*u ", LOG_INDEX_LEN, *current_index); - } else { - printf("| %-*s ", LOG_INDEX_LEN, ""); - } - // printf("| %*u ", LOG_INDEX_LEN, structurals.next_tape_index()); - printf("| %-s ", detail); - printf("|\n"); - } -} - -} // namespace logger -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage2/logger.h */ - -namespace simdjson { -namespace westmere { -namespace { -namespace stage2 { - -class json_iterator { - public: - const uint8_t *const buf; - uint32_t *next_structural; - dom_parser_implementation &dom_parser; - uint32_t depth{0}; - - /** - * Walk the JSON document. - * - * The visitor receives callbacks when values are encountered. All callbacks - * pass the iterator as - * the first parameter; some callbacks have other parameters as well: - * - * - visit_document_start() - at the beginning. - * - visit_document_end() - at the end (if things were successful). - * - * - visit_array_start() - at the start `[` of a non-empty array. - * - visit_array_end() - at the end `]` of a non-empty array. 
- * - visit_empty_array() - when an empty array is encountered. - * - * - visit_object_end() - at the start `]` of a non-empty object. - * - visit_object_start() - at the end `]` of a non-empty object. - * - visit_empty_object() - when an empty object is encountered. - * - visit_key(const uint8_t *key) - when a key in an object field is - * encountered. key is - * guaranteed to point at the first quote - * of the string (`"key"`). - * - visit_primitive(const uint8_t *value) - when a value is a string, - * number, boolean or null. - * - visit_root_primitive(iter, uint8_t *value) - when the top-level value - * is a string, number, boolean or null. - * - * - increment_count(iter) - each time a value is found in an array or - * object. - */ - template - simdjson_warn_unused simdjson_really_inline error_code - walk_document(V &visitor) noexcept; - - /** - * Create an iterator capable of walking a JSON document. - * - * The document must have already passed through stage 1. - */ - simdjson_really_inline json_iterator(dom_parser_implementation &_dom_parser, - size_t start_structural_index); - - /** - * Look at the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *peek() const noexcept; - /** - * Advance to the next token. - * - * Tokens can be strings, numbers, booleans, null, or operators (`[{]},:`)). - * - * They may include invalid JSON as well (such as `1.2.3` or `ture`). - */ - simdjson_really_inline const uint8_t *advance() noexcept; - /** - * Get the remaining length of the document, from the start of the current - * token. - */ - simdjson_really_inline size_t remaining_len() const noexcept; - /** - * Check if we are at the end of the document. - * - * If this is true, there are no more tokens. 
- */ - simdjson_really_inline bool at_eof() const noexcept; - /** - * Check if we are at the beginning of the document. - */ - simdjson_really_inline bool at_beginning() const noexcept; - simdjson_really_inline uint8_t last_structural() const noexcept; - - /** - * Log that a value has been found. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_value(const char *type) const noexcept; - /** - * Log the start of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_start_value(const char *type) const - noexcept; - /** - * Log the end of a multipart value. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_end_value(const char *type) const noexcept; - /** - * Log an error. - * - * Set LOG_ENABLED=true in logger.h to see logging. - */ - simdjson_really_inline void log_error(const char *error) const noexcept; - - template - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(V &visitor, const uint8_t *value) noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(V &visitor, const uint8_t *value) noexcept; -}; - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::walk_document(V &visitor) noexcept { - logger::log_start(); - - // - // Start the document - // - if (at_eof()) { - return EMPTY; - } - log_start_value("document"); - SIMDJSON_TRY(visitor.visit_document_start(*this)); - - // - // Read first value - // - { - auto value = advance(); - - // Make sure the outer object or array is closed before continuing; - // otherwise, there are ways we - // could get into memory corruption. 
See - // https://github.com/simdjson/simdjson/issues/906 - if (!STREAMING) { - switch (*value) { - case '{': - if (last_structural() != '}') { - log_value("starting brace unmatched"); - return TAPE_ERROR; - }; - break; - case '[': - if (last_structural() != ']') { - log_value("starting bracket unmatched"); - return TAPE_ERROR; - }; - break; - } - } - - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_root_primitive(*this, value)); - break; - } - } - goto document_end; - -// -// Object parser states -// -object_begin: - log_start_value("object"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = false; - SIMDJSON_TRY(visitor.visit_object_start(*this)); - - { - auto key = advance(); - if (*key != '"') { - log_error("Object does not start with a key"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.increment_count(*this)); - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - -object_field: - if (simdjson_unlikely(*advance() != ':')) { - log_error("Missing colon after key in object"); - return TAPE_ERROR; - } - { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } - } - -object_continue: - switch (*advance()) { - case ',': - 
SIMDJSON_TRY(visitor.increment_count(*this)); - { - auto key = advance(); - if (simdjson_unlikely(*key != '"')) { - log_error( - "Key string missing at beginning of field in object"); - return TAPE_ERROR; - } - SIMDJSON_TRY(visitor.visit_key(*this, key)); - } - goto object_field; - case '}': - log_end_value("object"); - SIMDJSON_TRY(visitor.visit_object_end(*this)); - goto scope_end; - default: - log_error("No comma between object fields"); - return TAPE_ERROR; - } - -scope_end: - depth--; - if (depth == 0) { - goto document_end; - } - if (dom_parser.is_array[depth]) { - goto array_continue; - } - goto object_continue; - -// -// Array parser states -// -array_begin: - log_start_value("array"); - depth++; - if (depth >= dom_parser.max_depth()) { - log_error("Exceeded max depth!"); - return DEPTH_ERROR; - } - dom_parser.is_array[depth] = true; - SIMDJSON_TRY(visitor.visit_array_start(*this)); - SIMDJSON_TRY(visitor.increment_count(*this)); - -array_value : { - auto value = advance(); - switch (*value) { - case '{': - if (*peek() == '}') { - advance(); - log_value("empty object"); - SIMDJSON_TRY(visitor.visit_empty_object(*this)); - break; - } - goto object_begin; - case '[': - if (*peek() == ']') { - advance(); - log_value("empty array"); - SIMDJSON_TRY(visitor.visit_empty_array(*this)); - break; - } - goto array_begin; - default: - SIMDJSON_TRY(visitor.visit_primitive(*this, value)); - break; - } -} - -array_continue: - switch (*advance()) { - case ',': - SIMDJSON_TRY(visitor.increment_count(*this)); - goto array_value; - case ']': - log_end_value("array"); - SIMDJSON_TRY(visitor.visit_array_end(*this)); - goto scope_end; - default: - log_error("Missing comma between array values"); - return TAPE_ERROR; - } - -document_end: - log_end_value("document"); - SIMDJSON_TRY(visitor.visit_document_end(*this)); - - dom_parser.next_structural_index = - uint32_t(next_structural - &dom_parser.structural_indexes[0]); - - // If we didn't make it to the end, it's an error - if 
(!STREAMING && - dom_parser.next_structural_index != dom_parser.n_structural_indexes) { - log_error( - "More than one JSON value at the root of the document, or extra " - "characters at the end of the JSON!"); - return TAPE_ERROR; - } - - return SUCCESS; - -} // walk_document() - -simdjson_really_inline json_iterator::json_iterator( - dom_parser_implementation &_dom_parser, size_t start_structural_index) - : buf{_dom_parser.buf}, - next_structural{&_dom_parser.structural_indexes[start_structural_index]}, - dom_parser{_dom_parser} {} - -simdjson_really_inline const uint8_t *json_iterator::peek() const noexcept { - return &buf[*(next_structural)]; -} -simdjson_really_inline const uint8_t *json_iterator::advance() noexcept { - return &buf[*(next_structural++)]; -} -simdjson_really_inline size_t json_iterator::remaining_len() const noexcept { - return dom_parser.len - *(next_structural - 1); -} - -simdjson_really_inline bool json_iterator::at_eof() const noexcept { - return next_structural == - &dom_parser.structural_indexes[dom_parser.n_structural_indexes]; -} -simdjson_really_inline bool json_iterator::at_beginning() const noexcept { - return next_structural == dom_parser.structural_indexes.get(); -} -simdjson_really_inline uint8_t json_iterator::last_structural() const noexcept { - return buf[dom_parser - .structural_indexes[dom_parser.n_structural_indexes - 1]]; -} - -simdjson_really_inline void json_iterator::log_value(const char *type) const - noexcept { - logger::log_line(*this, "", type, ""); -} - -simdjson_really_inline void json_iterator::log_start_value( - const char *type) const noexcept { - logger::log_line(*this, "+", type, ""); - if (logger::LOG_ENABLED) { - logger::log_depth++; - } -} - -simdjson_really_inline void json_iterator::log_end_value(const char *type) const - noexcept { - if (logger::LOG_ENABLED) { - logger::log_depth--; - } - logger::log_line(*this, "-", type, ""); -} - -simdjson_really_inline void json_iterator::log_error(const char *error) 
const - noexcept { - logger::log_line(*this, "", "ERROR", error); -} - -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_root_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_root_string(*this, value); - case 't': - return visitor.visit_root_true_atom(*this, value); - case 'f': - return visitor.visit_root_false_atom(*this, value); - case 'n': - return visitor.visit_root_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_root_number(*this, value); - default: - log_error("Document starts with a non-value character"); - return TAPE_ERROR; - } -} -template -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::visit_primitive(V &visitor, const uint8_t *value) noexcept { - switch (*value) { - case '"': - return visitor.visit_string(*this, value); - case 't': - return visitor.visit_true_atom(*this, value); - case 'f': - return visitor.visit_false_atom(*this, value); - case 'n': - return visitor.visit_null_atom(*this, value); - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return visitor.visit_number(*this, value); - default: - log_error("Non-value found when value was expected!"); - return TAPE_ERROR; - } -} - -} // namespace stage2 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage2/json_iterator.h */ -/* begin file src/generic/stage2/tape_writer.h */ -namespace simdjson { -namespace westmere { -namespace { -namespace stage2 { - -struct tape_writer { - /** The next place to write to tape */ - uint64_t *next_tape_loc; - - /** Write a signed 64-bit value to tape. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - - /** Write an unsigned 64-bit value to tape. 
*/ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - - /** Write a double value to tape. */ - simdjson_really_inline void append_double(double value) noexcept; - - /** - * Append a tape entry (an 8-bit type,and 56 bits worth of value). - */ - simdjson_really_inline void append(uint64_t val, - internal::tape_type t) noexcept; - - /** - * Skip the current tape entry without writing. - * - * Used to skip the start of the container, since we'll come back later to - * fill it in when the - * container ends. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Skip the number of tape entries necessary to write a large u64 or i64. - */ - simdjson_really_inline void skip_large_integer() noexcept; - - /** - * Skip the number of tape entries necessary to write a double. - */ - simdjson_really_inline void skip_double() noexcept; - - /** - * Write a value to a known location on tape. - * - * Used to go back and write out the start of a container after the - * container ends. - */ - simdjson_really_inline static void write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept; - - private: - /** - * Append both the tape entry, and a supplementary value following it. Used - * for types that need - * all 64 bits, such as double and uint64_t. - */ - template - simdjson_really_inline void append2(uint64_t val, - T val2, - internal::tape_type t) noexcept; -}; // struct number_writer - -simdjson_really_inline void tape_writer::append_s64(int64_t value) noexcept { - append2(0, value, internal::tape_type::INT64); -} - -simdjson_really_inline void tape_writer::append_u64(uint64_t value) noexcept { - append(0, internal::tape_type::UINT64); - *next_tape_loc = value; - next_tape_loc++; -} - -/** Write a double value to tape. 
*/ -simdjson_really_inline void tape_writer::append_double(double value) noexcept { - append2(0, value, internal::tape_type::DOUBLE); -} - -simdjson_really_inline void tape_writer::skip() noexcept { next_tape_loc++; } - -simdjson_really_inline void tape_writer::skip_large_integer() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::skip_double() noexcept { - next_tape_loc += 2; -} - -simdjson_really_inline void tape_writer::append( - uint64_t val, internal::tape_type t) noexcept { - *next_tape_loc = val | ((uint64_t(char(t))) << 56); - next_tape_loc++; -} - -template -simdjson_really_inline void tape_writer::append2( - uint64_t val, T val2, internal::tape_type t) noexcept { - append(val, t); - static_assert(sizeof(val2) == sizeof(*next_tape_loc), - "Type is not 64 bits!"); - memcpy(next_tape_loc, &val2, sizeof(val2)); - next_tape_loc++; -} - -simdjson_really_inline void tape_writer::write(uint64_t &tape_loc, - uint64_t val, - internal::tape_type t) noexcept { - tape_loc = val | ((uint64_t(char(t))) << 56); -} - -} // namespace stage2 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage2/tape_writer.h */ - -namespace simdjson { -namespace westmere { -namespace { -namespace stage2 { - -struct tape_builder { - template - simdjson_warn_unused static simdjson_really_inline error_code - parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept; - - /** Called when a non-empty document starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_start(json_iterator &iter) noexcept; - /** Called when a non-empty document ends without error. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_document_end(json_iterator &iter) noexcept; - - /** Called when a non-empty array starts. 
*/ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_start(json_iterator &iter) noexcept; - /** Called when a non-empty array ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_array_end(json_iterator &iter) noexcept; - /** Called when an empty array is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_array(json_iterator &iter) noexcept; - - /** Called when a non-empty object starts. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_start(json_iterator &iter) noexcept; - /** - * Called when a key in a field is encountered. - * - * primitive, visit_object_start, visit_empty_object, visit_array_start, or - * visit_empty_array - * will be called after this with the field value. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_key(json_iterator &iter, const uint8_t *key) noexcept; - /** Called when a non-empty object ends. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_object_end(json_iterator &iter) noexcept; - /** Called when an empty object is found. */ - simdjson_warn_unused simdjson_really_inline error_code - visit_empty_object(json_iterator &iter) noexcept; - - /** - * Called when a string, number, boolean or null is found. - */ - simdjson_warn_unused simdjson_really_inline error_code - visit_primitive(json_iterator &iter, const uint8_t *value) noexcept; - /** - * Called when a string, number, boolean or null is found at the top level - * of a document (i.e. - * when there is no array or object and the entire document is a single - * string, number, boolean or - * null. - * - * This is separate from primitive() because simdjson's normal primitive - * parsing routines assume - * there is at least one more token after the value, which is only true in - * an array or object. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - visit_root_primitive(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code visit_string( - json_iterator &iter, const uint8_t *value, bool key = false) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - visit_root_string(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_number(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_true_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_false_atom(json_iterator &iter, const uint8_t *value) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - visit_root_null_atom(json_iterator &iter, const uint8_t *value) noexcept; - - /** Called each time a new field or element in an array or object is found. 
- */ - simdjson_warn_unused simdjson_really_inline error_code - increment_count(json_iterator &iter) noexcept; - - /** Next location to write to tape */ - tape_writer tape; - - private: - /** Next write location in the string buf for stage 2 parsing */ - uint8_t *current_string_buf_loc; - - simdjson_really_inline tape_builder(dom::document &doc) noexcept; - - simdjson_really_inline uint32_t next_tape_index(json_iterator &iter) const - noexcept; - simdjson_really_inline void start_container(json_iterator &iter) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_warn_unused simdjson_really_inline error_code - empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept; - simdjson_really_inline uint8_t *on_start_string( - json_iterator &iter) noexcept; - simdjson_really_inline void on_end_string(uint8_t *dst) noexcept; -}; // class tape_builder - -template -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::parse_document(dom_parser_implementation &dom_parser, - dom::document &doc) noexcept { - dom_parser.doc = &doc; - json_iterator iter(dom_parser, - STREAMING ? 
dom_parser.next_structural_index : 0); - tape_builder builder(doc); - return iter.walk_document(builder); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_root_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_primitive(json_iterator &iter, - const uint8_t *value) noexcept { - return iter.visit_primitive(*this, value); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_object(json_iterator &iter) noexcept { - return empty_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_empty_array(json_iterator &iter) noexcept { - return empty_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_document_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_start(json_iterator &iter) noexcept { - start_container(iter); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_object_end(json_iterator &iter) noexcept { - return end_container(iter, - internal::tape_type::START_OBJECT, - internal::tape_type::END_OBJECT); -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_array_end(json_iterator &iter) noexcept { - return end_container( - iter, internal::tape_type::START_ARRAY, internal::tape_type::END_ARRAY); -} -simdjson_warn_unused simdjson_really_inline error_code 
-tape_builder::visit_document_end(json_iterator &iter) noexcept { - constexpr uint32_t start_tape_index = 0; - tape.append(start_tape_index, internal::tape_type::ROOT); - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter), - internal::tape_type::ROOT); - return SUCCESS; -} -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_key(json_iterator &iter, const uint8_t *key) noexcept { - return visit_string(iter, key, true); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::increment_count(json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth] - .count++; // we have a key value pair in the object at - // parser.dom_parser.depth - 1 - return SUCCESS; -} - -simdjson_really_inline tape_builder::tape_builder(dom::document &doc) noexcept - : tape{doc.tape.get()}, - current_string_buf_loc{doc.string_buf.get()} {} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_string(json_iterator &iter, - const uint8_t *value, - bool key) noexcept { - iter.log_value(key ? "key" : "string"); - uint8_t *dst = on_start_string(iter); - dst = stringparsing::parse_string(value + 1, dst); - if (dst == nullptr) { - iter.log_error("Invalid escape in string"); - return STRING_ERROR; - } - on_end_string(dst); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_string(json_iterator &iter, - const uint8_t *value) noexcept { - return visit_string(iter, value); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_number(json_iterator &iter, const uint8_t *value) noexcept { - iter.log_value("number"); - return numberparsing::parse_number(value, tape); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_number(json_iterator &iter, - const uint8_t *value) noexcept { - // - // We need to make a copy to make sure that the string is space terminated. 
- // This is not about padding the input, which should already padded up - // to len + SIMDJSON_PADDING. However, we have no control at this stage - // on how the padding was done. What if the input string was padded with - // nulls? - // It is quite common for an input string to have an extra null character (C - // string). - // We do not want to allow 9\0 (where \0 is the null character) inside a - // JSON - // document, but the string "9\0" by itself is fine. So we make a copy and - // pad the input with spaces when we know that there is just one input - // element. - // This copy is relatively expensive, but it will almost never be called in - // practice unless you are in the strange scenario where you have many JSON - // documents made of single atoms. - // - std::unique_ptr copy( - new (std::nothrow) uint8_t[iter.remaining_len() + SIMDJSON_PADDING]); - if (copy.get() == nullptr) { - return MEMALLOC; - } - std::memcpy(copy.get(), value, iter.remaining_len()); - std::memset(copy.get() + iter.remaining_len(), ' ', SIMDJSON_PADDING); - error_code error = visit_number(iter, copy.get()); - return error; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value)) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_true_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("true"); - if (!atomparsing::is_valid_true_atom(value, iter.remaining_len())) { - return T_ATOM_ERROR; - } - tape.append(0, internal::tape_type::TRUE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if 
(!atomparsing::is_valid_false_atom(value)) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_false_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("false"); - if (!atomparsing::is_valid_false_atom(value, iter.remaining_len())) { - return F_ATOM_ERROR; - } - tape.append(0, internal::tape_type::FALSE_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value)) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::visit_root_null_atom(json_iterator &iter, - const uint8_t *value) noexcept { - iter.log_value("null"); - if (!atomparsing::is_valid_null_atom(value, iter.remaining_len())) { - return N_ATOM_ERROR; - } - tape.append(0, internal::tape_type::NULL_VALUE); - return SUCCESS; -} - -// private: - -simdjson_really_inline uint32_t -tape_builder::next_tape_index(json_iterator &iter) const noexcept { - return uint32_t(tape.next_tape_loc - iter.dom_parser.doc->tape.get()); -} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::empty_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - auto start_index = next_tape_index(iter); - tape.append(start_index + 2, start); - tape.append(start_index, end); - return SUCCESS; -} - -simdjson_really_inline void tape_builder::start_container( - json_iterator &iter) noexcept { - iter.dom_parser.open_containers[iter.depth].tape_index = - next_tape_index(iter); - iter.dom_parser.open_containers[iter.depth].count = 0; - tape.skip(); // We don't actually *write* the start element until the end. 
-} - -simdjson_warn_unused simdjson_really_inline error_code -tape_builder::end_container(json_iterator &iter, - internal::tape_type start, - internal::tape_type end) noexcept { - // Write the ending tape element, pointing at the start location - const uint32_t start_tape_index = - iter.dom_parser.open_containers[iter.depth].tape_index; - tape.append(start_tape_index, end); - // Write the start tape element, pointing at the end location (and including - // count) - // count can overflow if it exceeds 24 bits... so we saturate - // the convention being that a cnt of 0xffffff or more is undetermined in - // value (>= 0xffffff). - const uint32_t count = iter.dom_parser.open_containers[iter.depth].count; - const uint32_t cntsat = count > 0xFFFFFF ? 0xFFFFFF : count; - tape_writer::write(iter.dom_parser.doc->tape[start_tape_index], - next_tape_index(iter) | (uint64_t(cntsat) << 32), - start); - return SUCCESS; -} - -simdjson_really_inline uint8_t *tape_builder::on_start_string( - json_iterator &iter) noexcept { - // we advance the point, accounting for the fact that we have a NULL - // termination - tape.append(current_string_buf_loc - iter.dom_parser.doc->string_buf.get(), - internal::tape_type::STRING); - return current_string_buf_loc + sizeof(uint32_t); -} - -simdjson_really_inline void tape_builder::on_end_string(uint8_t *dst) noexcept { - uint32_t str_length = - uint32_t(dst - (current_string_buf_loc + sizeof(uint32_t))); - // TODO check for overflow in case someone has a crazy string (>=4GB?) - // But only add the overflow check when the document itself exceeds 4GB - // Currently unneeded because we refuse to parse docs larger or equal to - // 4GB. - memcpy(current_string_buf_loc, &str_length, sizeof(uint32_t)); - // NULL termination is still handy if you expect all your strings to - // be NULL terminated? 
It comes at a small cost - *dst = 0; - current_string_buf_loc = dst + 1; -} - -} // namespace stage2 -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file src/generic/stage2/tape_builder.h */ - -// -// Implementation-specific overrides -// - -namespace simdjson { -namespace westmere { -namespace { -namespace stage1 { - -simdjson_really_inline uint64_t -json_string_scanner::find_escaped(uint64_t backslash) { - if (!backslash) { - uint64_t escaped = prev_escaped; - prev_escaped = 0; - return escaped; - } - return find_escaped_branchless(backslash); -} - -} // namespace stage1 -} // unnamed namespace - -simdjson_warn_unused error_code implementation::minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept { - return westmere::stage1::json_minifier::minify<64>(buf, len, dst, dst_len); -} - -simdjson_warn_unused error_code dom_parser_implementation::stage1( - const uint8_t *_buf, size_t _len, stage1_mode streaming) noexcept { - this->buf = _buf; - this->len = _len; - return westmere::stage1::json_structural_indexer::index<64>( - _buf, _len, *this, streaming); -} - -simdjson_warn_unused bool implementation::validate_utf8(const char *buf, - size_t len) const - noexcept { - return westmere::stage1::generic_validate_utf8(buf, len); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code -dom_parser_implementation::stage2_next(dom::document &_doc) noexcept { - return stage2::tape_builder::parse_document(*this, _doc); -} - -simdjson_warn_unused error_code dom_parser_implementation::parse( - const uint8_t *_buf, size_t _len, dom::document &_doc) noexcept { - auto error = stage1(_buf, _len, stage1_mode::regular); - if (error) { - return error; - } - return stage2(_doc); -} - -} // namespace westmere -} // namespace simdjson - -/* begin file 
include/simdjson/westmere/end.h */ -SIMDJSON_UNTARGET_WESTMERE -/* end file include/simdjson/westmere/end.h */ -/* end file src/westmere/dom_parser_implementation.cpp */ -#endif - -SIMDJSON_POP_DISABLE_WARNINGS -/* end file src/simdjson.cpp */ diff --git a/speechx/speechx/utils/simdjson.h b/speechx/speechx/utils/simdjson.h deleted file mode 100644 index 28a9239b1..000000000 --- a/speechx/speechx/utils/simdjson.h +++ /dev/null @@ -1,37881 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* auto-generated on 2022-01-31 11:38:54 -0500. Do not edit! */ -/* begin file include/simdjson.h */ -#ifndef SIMDJSON_H -#define SIMDJSON_H - -/** - * @mainpage - * - * Check the - [README.md](https://github.com/simdjson/simdjson/blob/master/README.md#simdjson--parsing-gigabytes-of-json-per-second). - * - * Sample code. See - https://github.com/simdjson/simdjson/blob/master/doc/basics.md for more - examples. - - #include "simdjson.h" - - int main(void) { - // load from `twitter.json` file: - simdjson::dom::parser parser; - simdjson::dom::element tweets = parser.load("twitter.json"); - std::cout << tweets["search_metadata"]["count"] << " results." 
<< - std::endl; - - // Parse and iterate through an array of objects - auto abstract_json = R"( [ - { "12345" : {"a":12.34, "b":56.78, "c": 9998877} }, - { "12545" : {"a":11.44, "b":12.78, "c": 11111111} } - ] )"_padded; - - for (simdjson::dom::object obj : parser.parse(abstract_json)) { - for(const auto key_value : obj) { - cout << "key: " << key_value.key << " : "; - simdjson::dom::object innerobj = key_value.value; - cout << "a: " << double(innerobj["a"]) << ", "; - cout << "b: " << double(innerobj["b"]) << ", "; - cout << "c: " << int64_t(innerobj["c"]) << endl; - } - } - } - */ - -/* begin file include/simdjson/simdjson_version.h */ -// /include/simdjson/simdjson_version.h automatically generated by release.py, -// do not change by hand -#ifndef SIMDJSON_SIMDJSON_VERSION_H -#define SIMDJSON_SIMDJSON_VERSION_H - -/** The version of simdjson being used (major.minor.revision) */ -#define SIMDJSON_VERSION 1.0.2 - -namespace simdjson { -enum { - /** - * The major version (MAJOR.minor.revision) of simdjson being used. - */ - SIMDJSON_VERSION_MAJOR = 1, - /** - * The minor version (major.MINOR.revision) of simdjson being used. - */ - SIMDJSON_VERSION_MINOR = 0, - /** - * The revision (major.minor.REVISION) of simdjson being used. - */ - SIMDJSON_VERSION_REVISION = 2 -}; -} // namespace simdjson - -#endif // SIMDJSON_SIMDJSON_VERSION_H -/* end file include/simdjson/simdjson_version.h */ -/* begin file include/simdjson/dom.h */ -#ifndef SIMDJSON_DOM_H -#define SIMDJSON_DOM_H - -/* begin file include/simdjson/base.h */ -#ifndef SIMDJSON_BASE_H -#define SIMDJSON_BASE_H - -/* begin file include/simdjson/compiler_check.h */ -#ifndef SIMDJSON_COMPILER_CHECK_H -#define SIMDJSON_COMPILER_CHECK_H - -#ifndef __cplusplus -#error simdjson requires a C++ compiler -#endif - -#ifndef SIMDJSON_CPLUSPLUS -#if defined(_MSVC_LANG) && !defined(__clang__) -#define SIMDJSON_CPLUSPLUS (_MSC_VER == 1900 ? 
201103L : _MSVC_LANG) -#else -#define SIMDJSON_CPLUSPLUS __cplusplus -#endif -#endif - -// C++ 17 -#if !defined(SIMDJSON_CPLUSPLUS17) && (SIMDJSON_CPLUSPLUS >= 201703L) -#define SIMDJSON_CPLUSPLUS17 1 -#endif - -// C++ 14 -#if !defined(SIMDJSON_CPLUSPLUS14) && (SIMDJSON_CPLUSPLUS >= 201402L) -#define SIMDJSON_CPLUSPLUS14 1 -#endif - -// C++ 11 -#if !defined(SIMDJSON_CPLUSPLUS11) && (SIMDJSON_CPLUSPLUS >= 201103L) -#define SIMDJSON_CPLUSPLUS11 1 -#endif - -#ifndef SIMDJSON_CPLUSPLUS11 -#error simdjson requires a compiler compliant with the C++11 standard -#endif - -#endif // SIMDJSON_COMPILER_CHECK_H -/* end file include/simdjson/compiler_check.h */ -/* begin file include/simdjson/common_defs.h */ -#ifndef SIMDJSON_COMMON_DEFS_H -#define SIMDJSON_COMMON_DEFS_H - -#include -/* begin file include/simdjson/portability.h */ -#ifndef SIMDJSON_PORTABILITY_H -#define SIMDJSON_PORTABILITY_H - -#include -#include -#include -#include -#include -#ifndef _WIN32 -// strcasecmp, strncasecmp -#include -#endif - -#ifdef _MSC_VER -#define SIMDJSON_VISUAL_STUDIO 1 -/** - * We want to differentiate carefully between - * clang under visual studio and regular visual - * studio. - * - * Under clang for Windows, we enable: - * * target pragmas so that part and only part of the - * code gets compiled for advanced instructions. - * - */ -#ifdef __clang__ -// clang under visual studio -#define SIMDJSON_CLANG_VISUAL_STUDIO 1 -#else -// just regular visual studio (best guess) -#define SIMDJSON_REGULAR_VISUAL_STUDIO 1 -#endif // __clang__ -#endif // _MSC_VER - -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO -// https://en.wikipedia.org/wiki/C_alternative_tokens -// This header should have no effect, except maybe -// under Visual Studio. 
-#include -#endif - -#if defined(__x86_64__) || defined(_M_AMD64) -#define SIMDJSON_IS_X86_64 1 -#elif defined(__aarch64__) || defined(_M_ARM64) -#define SIMDJSON_IS_ARM64 1 -#elif defined(__PPC64__) || defined(_M_PPC64) -#define SIMDJSON_IS_PPC64 1 -#else -#define SIMDJSON_IS_32BITS 1 - -// We do not support 32-bit platforms, but it can be -// handy to identify them. -#if defined(_M_IX86) || defined(__i386__) -#define SIMDJSON_IS_X86_32BITS 1 -#elif defined(__arm__) || defined(_M_ARM) -#define SIMDJSON_IS_ARM_32BITS 1 -#elif defined(__PPC__) || defined(_M_PPC) -#define SIMDJSON_IS_PPC_32BITS 1 -#endif - -#endif // defined(__x86_64__) || defined(_M_AMD64) - -#ifdef SIMDJSON_IS_32BITS -#ifndef SIMDJSON_NO_PORTABILITY_WARNING -#pragma message( \ - "The simdjson library is designed \ -for 64-bit processors and it seems that you are not \ -compiling for a known 64-bit platform. All fast kernels \ -will be disabled and performance may be poor. Please \ -use a 64-bit target such as x64, 64-bit ARM or 64-bit PPC.") -#endif // SIMDJSON_NO_PORTABILITY_WARNING -#endif // SIMDJSON_IS_32BITS - -// this is almost standard? -#undef SIMDJSON_STRINGIFY_IMPLEMENTATION_ -#undef SIMDJSON_STRINGIFY -#define SIMDJSON_STRINGIFY_IMPLEMENTATION_(a) #a -#define SIMDJSON_STRINGIFY(a) SIMDJSON_STRINGIFY_IMPLEMENTATION_(a) - -// Our fast kernels require 64-bit systems. -// -// On 32-bit x86, we lack 64-bit popcnt, lzcnt, blsr instructions. -// Furthermore, the number of SIMD registers is reduced. -// -// On 32-bit ARM, we would have smaller registers. -// -// The simdjson users should still have the fallback kernel. It is -// slower, but it should run everywhere. - -// -// Enable valid runtime implementations, and select -// SIMDJSON_BUILTIN_IMPLEMENTATION -// - -// We are going to use runtime dispatch. 
-#ifdef SIMDJSON_IS_X86_64 -#ifdef __clang__ -// clang does not have GCC push pop -// warning: clang attribute push can't be used within a namespace in clang up -// til 8.0 so SIMDJSON_TARGET_REGION and SIMDJSON_UNTARGET_REGION must be -// *outside* of a -// namespace. -#define SIMDJSON_TARGET_REGION(T) \ - _Pragma(SIMDJSON_STRINGIFY(clang attribute push( \ - __attribute__((target(T))), apply_to = function))) -#define SIMDJSON_UNTARGET_REGION _Pragma("clang attribute pop") -#elif defined(__GNUC__) -// GCC is easier -#define SIMDJSON_TARGET_REGION(T) \ - _Pragma("GCC push_options") _Pragma(SIMDJSON_STRINGIFY(GCC target(T))) -#define SIMDJSON_UNTARGET_REGION _Pragma("GCC pop_options") -#endif // clang then gcc - -#endif // x86 - -// Default target region macros don't do anything. -#ifndef SIMDJSON_TARGET_REGION -#define SIMDJSON_TARGET_REGION(T) -#define SIMDJSON_UNTARGET_REGION -#endif - -// Is threading enabled? -#if defined(_REENTRANT) || defined(_MT) -#ifndef SIMDJSON_THREADS_ENABLED -#define SIMDJSON_THREADS_ENABLED -#endif -#endif - -// workaround for large stack sizes under -O0. -// https://github.com/simdjson/simdjson/issues/691 -#ifdef __APPLE__ -#ifndef __OPTIMIZE__ -// Apple systems have small stack sizes in secondary threads. -// Lack of compiler optimization may generate high stack usage. -// Users may want to disable threads for safety, but only when -// in debug mode which we detect by the fact that the __OPTIMIZE__ -// macro is not defined. -#undef SIMDJSON_THREADS_ENABLED -#endif -#endif - - -#if defined(__clang__) -#define SIMDJSON_NO_SANITIZE_UNDEFINED __attribute__((no_sanitize("undefined"))) -#elif defined(__GNUC__) -#define SIMDJSON_NO_SANITIZE_UNDEFINED __attribute__((no_sanitize_undefined)) -#else -#define SIMDJSON_NO_SANITIZE_UNDEFINED -#endif - -#ifdef SIMDJSON_VISUAL_STUDIO -// This is one case where we do not distinguish between -// regular visual studio and clang under visual studio. 
-// clang under Windows has _stricmp (like visual studio) but not strcasecmp (as -// clang normally has) -#define simdjson_strcasecmp _stricmp -#define simdjson_strncasecmp _strnicmp -#else -// The strcasecmp, strncasecmp, and strcasestr functions do not work with -// multibyte strings (e.g. UTF-8). -// So they are only useful for ASCII in our context. -// https://www.gnu.org/software/libunistring/manual/libunistring.html#char-_002a-strings -#define simdjson_strcasecmp strcasecmp -#define simdjson_strncasecmp strncasecmp -#endif - -#ifdef NDEBUG - -#ifdef SIMDJSON_VISUAL_STUDIO -#define SIMDJSON_UNREACHABLE() __assume(0) -#define SIMDJSON_ASSUME(COND) __assume(COND) -#else -#define SIMDJSON_UNREACHABLE() __builtin_unreachable(); -#define SIMDJSON_ASSUME(COND) \ - do { \ - if (!(COND)) __builtin_unreachable(); \ - } while (0) -#endif - -#else // NDEBUG - -#define SIMDJSON_UNREACHABLE() assert(0); -#define SIMDJSON_ASSUME(COND) assert(COND) - -#endif - -#endif // SIMDJSON_PORTABILITY_H -/* end file include/simdjson/portability.h */ - -namespace simdjson { - -namespace internal { -/** - * @private - * Our own implementation of the C++17 to_chars function. - * Defined in src/to_chars - */ -char *to_chars(char *first, const char *last, double value); -/** - * @private - * A number parsing routine. - * Defined in src/from_chars - */ -double from_chars(const char *first) noexcept; -double from_chars(const char *first, const char *end) noexcept; -} - -#ifndef SIMDJSON_EXCEPTIONS -#if __cpp_exceptions -#define SIMDJSON_EXCEPTIONS 1 -#else -#define SIMDJSON_EXCEPTIONS 0 -#endif -#endif - -/** The maximum document size supported by simdjson. */ -constexpr size_t SIMDJSON_MAXSIZE_BYTES = 0xFFFFFFFF; - -/** - * The amount of padding needed in a buffer to parse JSON. 
- * - * the input buf should be readable up to buf + SIMDJSON_PADDING - * this is a stopgap; there should be a better description of the - * main loop and its behavior that abstracts over this - * See https://github.com/simdjson/simdjson/issues/174 - */ -constexpr size_t SIMDJSON_PADDING = 32; - -/** - * By default, simdjson supports this many nested objects and arrays. - * - * This is the default for parser::max_depth(). - */ -constexpr size_t DEFAULT_MAX_DEPTH = 1024; - -} // namespace simdjson - -#if defined(__GNUC__) -// Marks a block with a name so that MCA analysis can see it. -#define SIMDJSON_BEGIN_DEBUG_BLOCK(name) \ - __asm volatile("# LLVM-MCA-BEGIN " #name); -#define SIMDJSON_END_DEBUG_BLOCK(name) __asm volatile("# LLVM-MCA-END " #name); -#define SIMDJSON_DEBUG_BLOCK(name, block) \ - BEGIN_DEBUG_BLOCK(name); \ - block; \ - END_DEBUG_BLOCK(name); -#else -#define SIMDJSON_BEGIN_DEBUG_BLOCK(name) -#define SIMDJSON_END_DEBUG_BLOCK(name) -#define SIMDJSON_DEBUG_BLOCK(name, block) -#endif - -// Align to N-byte boundary -#define SIMDJSON_ROUNDUP_N(a, n) (((a) + ((n)-1)) & ~((n)-1)) -#define SIMDJSON_ROUNDDOWN_N(a, n) ((a) & ~((n)-1)) - -#define SIMDJSON_ISALIGNED_N(ptr, n) (((uintptr_t)(ptr) & ((n)-1)) == 0) - -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) - -#define simdjson_really_inline __forceinline -#define simdjson_never_inline __declspec(noinline) - -#define simdjson_unused -#define simdjson_warn_unused - -#ifndef simdjson_likely -#define simdjson_likely(x) x -#endif -#ifndef simdjson_unlikely -#define simdjson_unlikely(x) x -#endif - -#define SIMDJSON_PUSH_DISABLE_WARNINGS __pragma(warning(push)) -#define SIMDJSON_PUSH_DISABLE_ALL_WARNINGS __pragma(warning(push, 0)) -#define SIMDJSON_DISABLE_VS_WARNING(WARNING_NUMBER) \ - __pragma(warning(disable : WARNING_NUMBER)) -// Get rid of Intellisense-only warnings (Code Analysis) -// Though __has_include is C++17, it is supported in Visual Studio 2017 or -// better (_MSC_VER>=1910). 
-#ifdef __has_include -#if __has_include() -#include -#define SIMDJSON_DISABLE_UNDESIRED_WARNINGS \ - SIMDJSON_DISABLE_VS_WARNING(ALL_CPPCORECHECK_WARNINGS) -#endif -#endif - -#ifndef SIMDJSON_DISABLE_UNDESIRED_WARNINGS -#define SIMDJSON_DISABLE_UNDESIRED_WARNINGS -#endif - -#define SIMDJSON_DISABLE_DEPRECATED_WARNING SIMDJSON_DISABLE_VS_WARNING(4996) -#define SIMDJSON_DISABLE_STRICT_OVERFLOW_WARNING -#define SIMDJSON_POP_DISABLE_WARNINGS __pragma(warning(pop)) - -#else // SIMDJSON_REGULAR_VISUAL_STUDIO - -#define simdjson_really_inline inline __attribute__((always_inline)) -#define simdjson_never_inline inline __attribute__((noinline)) - -#define simdjson_unused __attribute__((unused)) -#define simdjson_warn_unused __attribute__((warn_unused_result)) - -#ifndef simdjson_likely -#define simdjson_likely(x) __builtin_expect(!!(x), 1) -#endif -#ifndef simdjson_unlikely -#define simdjson_unlikely(x) __builtin_expect(!!(x), 0) -#endif - -#define SIMDJSON_PUSH_DISABLE_WARNINGS _Pragma("GCC diagnostic push") -// gcc doesn't seem to disable all warnings with all and extra, add warnings -// here as necessary -#define SIMDJSON_PUSH_DISABLE_ALL_WARNINGS \ - SIMDJSON_PUSH_DISABLE_WARNINGS \ - SIMDJSON_DISABLE_GCC_WARNING(-Weffc++) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wall) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wconversion) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wextra) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wattributes) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wimplicit - fallthrough) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wnon - virtual - dtor) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wreturn - type) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wshadow) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wunused - parameter) \ - SIMDJSON_DISABLE_GCC_WARNING(-Wunused - variable) -#define SIMDJSON_PRAGMA(P) _Pragma(#P) -#define SIMDJSON_DISABLE_GCC_WARNING(WARNING) \ - SIMDJSON_PRAGMA(GCC diagnostic ignored #WARNING) -#if defined(SIMDJSON_CLANG_VISUAL_STUDIO) -#define SIMDJSON_DISABLE_UNDESIRED_WARNINGS \ - 
SIMDJSON_DISABLE_GCC_WARNING(-Wmicrosoft - include) -#else -#define SIMDJSON_DISABLE_UNDESIRED_WARNINGS -#endif -#define SIMDJSON_DISABLE_DEPRECATED_WARNING \ - SIMDJSON_DISABLE_GCC_WARNING(-Wdeprecated - declarations) -#define SIMDJSON_DISABLE_STRICT_OVERFLOW_WARNING \ - SIMDJSON_DISABLE_GCC_WARNING(-Wstrict - overflow) -#define SIMDJSON_POP_DISABLE_WARNINGS _Pragma("GCC diagnostic pop") - - -#endif // MSC_VER - -#if defined(SIMDJSON_VISUAL_STUDIO) -/** - * Windows users need to do some extra work when building - * or using a dynamic library (DLL). When building, we need - * to set SIMDJSON_DLLIMPORTEXPORT to __declspec(dllexport). - * When *using* the DLL, the user needs to set - * SIMDJSON_DLLIMPORTEXPORT __declspec(dllimport). - * - * Static libraries not need require such work. - * - * It does not matter here whether you are using - * the regular visual studio or clang under visual - * studio, you still need to handle these issues. - * - * Non-Windows systems do not have this complexity. - */ -#if SIMDJSON_BUILDING_WINDOWS_DYNAMIC_LIBRARY -// We set SIMDJSON_BUILDING_WINDOWS_DYNAMIC_LIBRARY when we build a DLL under -// Windows. -// It should never happen that both SIMDJSON_BUILDING_WINDOWS_DYNAMIC_LIBRARY -// and -// SIMDJSON_USING_WINDOWS_DYNAMIC_LIBRARY are set. -#define SIMDJSON_DLLIMPORTEXPORT __declspec(dllexport) -#elif SIMDJSON_USING_WINDOWS_DYNAMIC_LIBRARY -// Windows user who call a dynamic library should set -// SIMDJSON_USING_WINDOWS_DYNAMIC_LIBRARY to 1. -#define SIMDJSON_DLLIMPORTEXPORT __declspec(dllimport) -#else -// We assume by default static linkage -#define SIMDJSON_DLLIMPORTEXPORT -#endif - -/** - * Workaround for the vcpkg package manager. Only vcpkg should - * ever touch the next line. The SIMDJSON_USING_LIBRARY macro is otherwise - * unused. - */ -#if SIMDJSON_USING_LIBRARY -#define SIMDJSON_DLLIMPORTEXPORT __declspec(dllimport) -#endif -/** - * End of workaround for the vcpkg package manager. 
- */ -#else -#define SIMDJSON_DLLIMPORTEXPORT -#endif - -// C++17 requires string_view. -#if SIMDJSON_CPLUSPLUS17 -#define SIMDJSON_HAS_STRING_VIEW -#include // by the standard, this has to be safe. -#endif - -// This macro (__cpp_lib_string_view) has to be defined -// for C++17 and better, but if it is otherwise defined, -// we are going to assume that string_view is available -// even if we do not have C++17 support. -#ifdef __cpp_lib_string_view -#define SIMDJSON_HAS_STRING_VIEW -#endif - -// Some systems have string_view even if we do not have C++17 support, -// and even if __cpp_lib_string_view is undefined, it is the case -// with Apple clang version 11. -// We must handle it. *This is important.* -#ifndef SIMDJSON_HAS_STRING_VIEW -#if defined __has_include -// do not combine the next #if with the previous one (unsafe) -#if __has_include() -// now it is safe to trigger the include -#include // though the file is there, it does not follow that we got the implementation -#if defined(_LIBCPP_STRING_VIEW) -// Ah! So we under libc++ which under its Library Fundamentals Technical -// Specification, which preceded C++17, -// included string_view. -// This means that we have string_view *even though* we may not have C++17. -#define SIMDJSON_HAS_STRING_VIEW -#endif // _LIBCPP_STRING_VIEW -#endif // __has_include () -#endif // defined __has_include -#endif // def SIMDJSON_HAS_STRING_VIEW -// end of complicated but important routine to try to detect string_view. - -// -// Backfill std::string_view using nonstd::string_view on systems where -// we expect that string_view is missing. Important: if we get this wrong, -// we will end up with two string_view definitions and potential trouble. -// That is why we work so hard above to avoid it. 
-// -#ifndef SIMDJSON_HAS_STRING_VIEW -SIMDJSON_PUSH_DISABLE_ALL_WARNINGS -/* begin file include/simdjson/nonstd/string_view.hpp */ -// Copyright 2017-2020 by Martin Moene -// -// string-view lite, a C++17-like string_view for C++98 and later. -// For more information see https://github.com/martinmoene/string-view-lite -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.txt or copy at -// http://www.boost.org/LICENSE_1_0.txt) - -#pragma once - -#ifndef NONSTD_SV_LITE_H_INCLUDED -#define NONSTD_SV_LITE_H_INCLUDED - -#define string_view_lite_MAJOR 1 -#define string_view_lite_MINOR 6 -#define string_view_lite_PATCH 0 - -#define string_view_lite_VERSION \ - nssv_STRINGIFY(string_view_lite_MAJOR) "." nssv_STRINGIFY( \ - string_view_lite_MINOR) "." nssv_STRINGIFY(string_view_lite_PATCH) - -#define nssv_STRINGIFY(x) nssv_STRINGIFY_(x) -#define nssv_STRINGIFY_(x) #x - -// string-view lite configuration: - -#define nssv_STRING_VIEW_DEFAULT 0 -#define nssv_STRING_VIEW_NONSTD 1 -#define nssv_STRING_VIEW_STD 2 - -// tweak header support: - -#ifdef __has_include -#if __has_include() -#include -#endif -#define nssv_HAVE_TWEAK_HEADER 1 -#else -#define nssv_HAVE_TWEAK_HEADER 0 -//# pragma message("string_view.hpp: Note: Tweak header not supported.") -#endif - -// string_view selection and configuration: - -#if !defined(nssv_CONFIG_SELECT_STRING_VIEW) -#define nssv_CONFIG_SELECT_STRING_VIEW \ - (nssv_HAVE_STD_STRING_VIEW ? 
nssv_STRING_VIEW_STD : nssv_STRING_VIEW_NONSTD) -#endif - -#ifndef nssv_CONFIG_STD_SV_OPERATOR -#define nssv_CONFIG_STD_SV_OPERATOR 0 -#endif - -#ifndef nssv_CONFIG_USR_SV_OPERATOR -#define nssv_CONFIG_USR_SV_OPERATOR 1 -#endif - -#ifdef nssv_CONFIG_CONVERSION_STD_STRING -#define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS \ - nssv_CONFIG_CONVERSION_STD_STRING -#define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS \ - nssv_CONFIG_CONVERSION_STD_STRING -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS -#define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS 1 -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -#define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS 1 -#endif - -#ifndef nssv_CONFIG_NO_STREAM_INSERTION -#define nssv_CONFIG_NO_STREAM_INSERTION 0 -#endif - -// Control presence of exception handling (try and auto discover): - -#ifndef nssv_CONFIG_NO_EXCEPTIONS -#if _MSC_VER -#include // for _HAS_EXCEPTIONS -#endif -#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (_HAS_EXCEPTIONS) -#define nssv_CONFIG_NO_EXCEPTIONS 0 -#else -#define nssv_CONFIG_NO_EXCEPTIONS 1 -#endif -#endif - -// C++ language version detection (C++20 is speculative): -// Note: VC14.0/1900 (VS2015) lacks too much from C++14. - -#ifndef nssv_CPLUSPLUS -#if defined(_MSVC_LANG) && !defined(__clang__) -#define nssv_CPLUSPLUS (_MSC_VER == 1900 ? 
201103L : _MSVC_LANG) -#else -#define nssv_CPLUSPLUS __cplusplus -#endif -#endif - -#define nssv_CPP98_OR_GREATER (nssv_CPLUSPLUS >= 199711L) -#define nssv_CPP11_OR_GREATER (nssv_CPLUSPLUS >= 201103L) -#define nssv_CPP11_OR_GREATER_ (nssv_CPLUSPLUS >= 201103L) -#define nssv_CPP14_OR_GREATER (nssv_CPLUSPLUS >= 201402L) -#define nssv_CPP17_OR_GREATER (nssv_CPLUSPLUS >= 201703L) -#define nssv_CPP20_OR_GREATER (nssv_CPLUSPLUS >= 202000L) - -// use C++17 std::string_view if available and requested: - -#if nssv_CPP17_OR_GREATER && defined(__has_include) -#if __has_include() -#define nssv_HAVE_STD_STRING_VIEW 1 -#else -#define nssv_HAVE_STD_STRING_VIEW 0 -#endif -#else -#define nssv_HAVE_STD_STRING_VIEW 0 -#endif - -#define nssv_USES_STD_STRING_VIEW \ - ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_STD) || \ - ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_DEFAULT) && \ - nssv_HAVE_STD_STRING_VIEW)) - -#define nssv_HAVE_STARTS_WITH \ - (nssv_CPP20_OR_GREATER || !nssv_USES_STD_STRING_VIEW) -#define nssv_HAVE_ENDS_WITH nssv_HAVE_STARTS_WITH - -// -// Use C++17 std::string_view: -// - -#if nssv_USES_STD_STRING_VIEW - -#include - -// Extensions for std::string: - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - -template > -std::basic_string to_string( - std::basic_string_view v, Allocator const &a = Allocator()) { - return std::basic_string(v.begin(), v.end(), a); -} - -template -std::basic_string_view to_string_view( - std::basic_string const &s) { - return std::basic_string_view(s.data(), s.size()); -} - -// Literal operators sv and _sv: - -#if nssv_CONFIG_STD_SV_OPERATOR - -using namespace std::literals::string_view_literals; - -#endif - -#if nssv_CONFIG_USR_SV_OPERATOR - -inline namespace literals { -inline namespace string_view_literals { -constexpr std::string_view operator"" _sv(const char *str, - size_t len) noexcept // (1) -{ - return std::string_view{str, len}; -} - -constexpr std::u16string_view operator"" _sv(const 
char16_t *str, - size_t len) noexcept // (2) -{ - return std::u16string_view{str, len}; -} - -constexpr std::u32string_view operator"" _sv(const char32_t *str, - size_t len) noexcept // (3) -{ - return std::u32string_view{str, len}; -} - -constexpr std::wstring_view operator"" _sv(const wchar_t *str, - size_t len) noexcept // (4) -{ - return std::wstring_view{str, len}; -} -} -} // namespace literals::string_view_literals - -#endif // nssv_CONFIG_USR_SV_OPERATOR - -} // namespace nonstd - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - -using std::string_view; -using std::wstring_view; -using std::u16string_view; -using std::u32string_view; -using std::basic_string_view; - -// literal "sv" and "_sv", see above - -using std::operator==; -using std::operator!=; -using std::operator<; -using std::operator<=; -using std::operator>; -using std::operator>=; - -using std::operator<<; - -} // namespace nonstd - -#else // nssv_HAVE_STD_STRING_VIEW - -// -// Before C++17: use string_view lite: -// - -// Compiler versions: -// -// MSVC++ 6.0 _MSC_VER == 1200 nssv_COMPILER_MSVC_VERSION == 60 (Visual -// Studio 6.0) -// MSVC++ 7.0 _MSC_VER == 1300 nssv_COMPILER_MSVC_VERSION == 70 (Visual -// Studio .NET 2002) -// MSVC++ 7.1 _MSC_VER == 1310 nssv_COMPILER_MSVC_VERSION == 71 (Visual -// Studio .NET 2003) -// MSVC++ 8.0 _MSC_VER == 1400 nssv_COMPILER_MSVC_VERSION == 80 (Visual -// Studio 2005) -// MSVC++ 9.0 _MSC_VER == 1500 nssv_COMPILER_MSVC_VERSION == 90 (Visual -// Studio 2008) -// MSVC++ 10.0 _MSC_VER == 1600 nssv_COMPILER_MSVC_VERSION == 100 (Visual -// Studio 2010) -// MSVC++ 11.0 _MSC_VER == 1700 nssv_COMPILER_MSVC_VERSION == 110 (Visual -// Studio 2012) -// MSVC++ 12.0 _MSC_VER == 1800 nssv_COMPILER_MSVC_VERSION == 120 (Visual -// Studio 2013) -// MSVC++ 14.0 _MSC_VER == 1900 nssv_COMPILER_MSVC_VERSION == 140 (Visual -// Studio 2015) -// MSVC++ 14.1 _MSC_VER >= 1910 nssv_COMPILER_MSVC_VERSION == 141 (Visual -// Studio 2017) -// MSVC++ 
14.2 _MSC_VER >= 1920 nssv_COMPILER_MSVC_VERSION == 142 (Visual -// Studio 2019) - -#if defined(_MSC_VER) && !defined(__clang__) -#define nssv_COMPILER_MSVC_VER (_MSC_VER) -#define nssv_COMPILER_MSVC_VERSION \ - (_MSC_VER / 10 - 10 * (5 + (_MSC_VER < 1900))) -#else -#define nssv_COMPILER_MSVC_VER 0 -#define nssv_COMPILER_MSVC_VERSION 0 -#endif - -#define nssv_COMPILER_VERSION(major, minor, patch) \ - (10 * (10 * (major) + (minor)) + (patch)) - -#if defined(__apple_build_version__) -#define nssv_COMPILER_APPLECLANG_VERSION \ - nssv_COMPILER_VERSION( \ - __clang_major__, __clang_minor__, __clang_patchlevel__) -#define nssv_COMPILER_CLANG_VERSION 0 -#elif defined(__clang__) -#define nssv_COMPILER_APPLECLANG_VERSION 0 -#define nssv_COMPILER_CLANG_VERSION \ - nssv_COMPILER_VERSION( \ - __clang_major__, __clang_minor__, __clang_patchlevel__) -#else -#define nssv_COMPILER_APPLECLANG_VERSION 0 -#define nssv_COMPILER_CLANG_VERSION 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) -#define nssv_COMPILER_GNUC_VERSION \ - nssv_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#else -#define nssv_COMPILER_GNUC_VERSION 0 -#endif - -// half-open range [lo..hi): -#define nssv_BETWEEN(v, lo, hi) ((lo) <= (v) && (v) < (hi)) - -// Presence of language and library features: - -#ifdef _HAS_CPP0X -#define nssv_HAS_CPP0X _HAS_CPP0X -#else -#define nssv_HAS_CPP0X 0 -#endif - -// Unless defined otherwise below, consider VC14 as C++11 for variant-lite: - -#if nssv_COMPILER_MSVC_VER >= 1900 -#undef nssv_CPP11_OR_GREATER -#define nssv_CPP11_OR_GREATER 1 -#endif - -#define nssv_CPP11_90 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1500) -#define nssv_CPP11_100 \ - (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1600) -#define nssv_CPP11_110 \ - (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1700) -#define nssv_CPP11_120 \ - (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1800) -#define nssv_CPP11_140 \ - (nssv_CPP11_OR_GREATER_ || 
nssv_COMPILER_MSVC_VER >= 1900) -#define nssv_CPP11_141 \ - (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1910) - -#define nssv_CPP14_000 (nssv_CPP14_OR_GREATER) -#define nssv_CPP17_000 (nssv_CPP17_OR_GREATER) - -// Presence of C++11 language features: - -#define nssv_HAVE_CONSTEXPR_11 nssv_CPP11_140 -#define nssv_HAVE_EXPLICIT_CONVERSION nssv_CPP11_140 -#define nssv_HAVE_INLINE_NAMESPACE nssv_CPP11_140 -#define nssv_HAVE_NOEXCEPT nssv_CPP11_140 -#define nssv_HAVE_NULLPTR nssv_CPP11_100 -#define nssv_HAVE_REF_QUALIFIER nssv_CPP11_140 -#define nssv_HAVE_UNICODE_LITERALS nssv_CPP11_140 -#define nssv_HAVE_USER_DEFINED_LITERALS nssv_CPP11_140 -#define nssv_HAVE_WCHAR16_T nssv_CPP11_100 -#define nssv_HAVE_WCHAR32_T nssv_CPP11_100 - -#if !((nssv_CPP11_OR_GREATER && nssv_COMPILER_CLANG_VERSION) || \ - nssv_BETWEEN(nssv_COMPILER_CLANG_VERSION, 300, 400)) -#define nssv_HAVE_STD_DEFINED_LITERALS nssv_CPP11_140 -#else -#define nssv_HAVE_STD_DEFINED_LITERALS 0 -#endif - -// Presence of C++14 language features: - -#define nssv_HAVE_CONSTEXPR_14 nssv_CPP14_000 - -// Presence of C++17 language features: - -#define nssv_HAVE_NODISCARD nssv_CPP17_000 - -// Presence of C++ library features: - -#define nssv_HAVE_STD_HASH nssv_CPP11_120 - -// Presence of compiler intrinsics: - -// Providing char-type specializations for compare() and length() that -// use compiler intrinsics can improve compile- and run-time performance. -// -// The challenge is in using the right combinations of builtin availability -// and its constexpr-ness. -// -// | compiler | __builtin_memcmp (constexpr) | memcmp (constexpr) | -// |----------|------------------------------|---------------------| -// | clang | 4.0 (>= 4.0 ) | any (? ) | -// | clang-a | 9.0 (>= 9.0 ) | any (? ) | -// | gcc | any (constexpr) | any (? ) | -// | msvc | >= 14.2 C++17 (>= 14.2 ) | any (? 
) | - -#define nssv_HAVE_BUILTIN_VER \ - ((nssv_CPP17_000 && nssv_COMPILER_MSVC_VERSION >= 142) || \ - nssv_COMPILER_GNUC_VERSION > 0 || nssv_COMPILER_CLANG_VERSION >= 400 || \ - nssv_COMPILER_APPLECLANG_VERSION >= 900) -#define nssv_HAVE_BUILTIN_CE (nssv_HAVE_BUILTIN_VER) - -#define nssv_HAVE_BUILTIN_MEMCMP \ - ((nssv_HAVE_CONSTEXPR_14 && nssv_HAVE_BUILTIN_CE) || \ - !nssv_HAVE_CONSTEXPR_14) -#define nssv_HAVE_BUILTIN_STRLEN \ - ((nssv_HAVE_CONSTEXPR_11 && nssv_HAVE_BUILTIN_CE) || \ - !nssv_HAVE_CONSTEXPR_11) - -#ifdef __has_builtin -#define nssv_HAVE_BUILTIN(x) __has_builtin(x) -#else -#define nssv_HAVE_BUILTIN(x) 0 -#endif - -#if nssv_HAVE_BUILTIN(__builtin_memcmp) || nssv_HAVE_BUILTIN_VER -#define nssv_BUILTIN_MEMCMP __builtin_memcmp -#else -#define nssv_BUILTIN_MEMCMP memcmp -#endif - -#if nssv_HAVE_BUILTIN(__builtin_strlen) || nssv_HAVE_BUILTIN_VER -#define nssv_BUILTIN_STRLEN __builtin_strlen -#else -#define nssv_BUILTIN_STRLEN strlen -#endif - -// C++ feature usage: - -#if nssv_HAVE_CONSTEXPR_11 -#define nssv_constexpr constexpr -#else -#define nssv_constexpr /*constexpr*/ -#endif - -#if nssv_HAVE_CONSTEXPR_14 -#define nssv_constexpr14 constexpr -#else -#define nssv_constexpr14 /*constexpr*/ -#endif - -#if nssv_HAVE_EXPLICIT_CONVERSION -#define nssv_explicit explicit -#else -#define nssv_explicit /*explicit*/ -#endif - -#if nssv_HAVE_INLINE_NAMESPACE -#define nssv_inline_ns inline -#else -#define nssv_inline_ns /*inline*/ -#endif - -#if nssv_HAVE_NOEXCEPT -#define nssv_noexcept noexcept -#else -#define nssv_noexcept /*noexcept*/ -#endif - -//#if nssv_HAVE_REF_QUALIFIER -//# define nssv_ref_qual & -//# define nssv_refref_qual && -//#else -//# define nssv_ref_qual /*&*/ -//# define nssv_refref_qual /*&&*/ -//#endif - -#if nssv_HAVE_NULLPTR -#define nssv_nullptr nullptr -#else -#define nssv_nullptr NULL -#endif - -#if nssv_HAVE_NODISCARD -#define nssv_nodiscard [[nodiscard]] -#else -#define nssv_nodiscard /*[[nodiscard]]*/ -#endif - -// Additional includes: - 
-#include -#include -#include -#include -#include // std::char_traits<> - -#if !nssv_CONFIG_NO_STREAM_INSERTION -#include -#endif - -#if !nssv_CONFIG_NO_EXCEPTIONS -#include -#endif - -#if nssv_CPP11_OR_GREATER -#include -#endif - -// Clang, GNUC, MSVC warning suppression macros: - -#if defined(__clang__) -#pragma clang diagnostic ignored "-Wreserved-user-defined-literal" -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wuser-defined-literals" -#elif defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wliteral-suffix" -#endif // __clang__ - -#if nssv_COMPILER_MSVC_VERSION >= 140 -#define nssv_SUPPRESS_MSGSL_WARNING(expr) [[gsl::suppress(expr)]] -#define nssv_SUPPRESS_MSVC_WARNING(code, descr) \ - __pragma(warning(suppress : code)) -#define nssv_DISABLE_MSVC_WARNINGS(codes) \ - __pragma(warning(push)) __pragma(warning(disable : codes)) -#else -#define nssv_SUPPRESS_MSGSL_WARNING(expr) -#define nssv_SUPPRESS_MSVC_WARNING(code, descr) -#define nssv_DISABLE_MSVC_WARNINGS(codes) -#endif - -#if defined(__clang__) -#define nssv_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") -#elif defined(__GNUC__) -#define nssv_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") -#elif nssv_COMPILER_MSVC_VERSION >= 140 -#define nssv_RESTORE_WARNINGS() __pragma(warning(pop)) -#else -#define nssv_RESTORE_WARNINGS() -#endif - -// Suppress the following MSVC (GSL) warnings: -// - C4455, non-gsl : 'operator ""sv': literal suffix identifiers that do not -// start with an underscore are reserved -// - C26472, gsl::t.1 : don't use a static_cast for arithmetic conversions; -// use brace initialization, gsl::narrow_cast or gsl::narow -// - C26481: gsl::b.1 : don't use pointer arithmetic. 
Use span instead - -nssv_DISABLE_MSVC_WARNINGS(4455 26481 26472) - // nssv_DISABLE_CLANG_WARNINGS( "-Wuser-defined-literals" ) - // nssv_DISABLE_GNUC_WARNINGS( -Wliteral-suffix ) - - namespace nonstd { - namespace sv_lite { - - namespace detail { - - // support constexpr comparison in C++14; - // for C++17 and later, use provided traits: - - template - inline nssv_constexpr14 int compare(CharT const *s1, - CharT const *s2, - std::size_t count) { - while (count-- != 0) { - if (*s1 < *s2) return -1; - if (*s1 > *s2) return +1; - ++s1; - ++s2; - } - return 0; - } - -#if nssv_HAVE_BUILTIN_MEMCMP - - // specialization of compare() for char, see also generic compare() above: - - inline nssv_constexpr14 int compare(char const *s1, - char const *s2, - std::size_t count) { - return nssv_BUILTIN_MEMCMP(s1, s2, count); - } - -#endif - -#if nssv_HAVE_BUILTIN_STRLEN - - // specialization of length() for char, see also generic length() further - // below: - - inline nssv_constexpr std::size_t length(char const *s) { - return nssv_BUILTIN_STRLEN(s); - } - -#endif - -#if defined(__OPTIMIZE__) - - // gcc, clang provide __OPTIMIZE__ - // Expect tail call optimization to make length() non-recursive: - - template - inline nssv_constexpr std::size_t length(CharT *s, std::size_t result = 0) { - return *s == '\0' ? 
result : length(s + 1, result + 1); - } - -#else // OPTIMIZE - - // non-recursive: - - template - inline nssv_constexpr14 std::size_t length(CharT *s) { - std::size_t result = 0; - while (*s++ != '\0') { - ++result; - } - return result; - } - -#endif // OPTIMIZE - - } // namespace detail - - template > - class basic_string_view; - - // - // basic_string_view: - // - - template */ - > - class basic_string_view { - public: - // Member types: - - typedef Traits traits_type; - typedef CharT value_type; - - typedef CharT *pointer; - typedef CharT const *const_pointer; - typedef CharT &reference; - typedef CharT const &const_reference; - - typedef const_pointer iterator; - typedef const_pointer const_iterator; - typedef std::reverse_iterator reverse_iterator; - typedef std::reverse_iterator const_reverse_iterator; - - typedef std::size_t size_type; - typedef std::ptrdiff_t difference_type; - - // 24.4.2.1 Construction and assignment: - - nssv_constexpr basic_string_view() nssv_noexcept : data_(nssv_nullptr), - size_(0) {} - -#if nssv_CPP11_OR_GREATER - nssv_constexpr basic_string_view(basic_string_view const &other) - nssv_noexcept = default; -#else - nssv_constexpr basic_string_view(basic_string_view const &other) - nssv_noexcept : data_(other.data_), - size_(other.size_) {} -#endif - - nssv_constexpr basic_string_view(CharT const *s, size_type count) - nssv_noexcept // non-standard noexcept - : data_(s), - size_(count) {} - - nssv_constexpr basic_string_view(CharT const *s) - nssv_noexcept // non-standard noexcept - : data_(s) -#if nssv_CPP17_OR_GREATER - , - size_(Traits::length(s)) -#elif nssv_CPP11_OR_GREATER - , - size_(detail::length(s)) -#else - , - size_(Traits::length(s)) -#endif - { - } - -// Assignment: - -#if nssv_CPP11_OR_GREATER - nssv_constexpr14 basic_string_view &operator=( - basic_string_view const &other) nssv_noexcept = default; -#else - nssv_constexpr14 basic_string_view &operator=( - basic_string_view const &other) nssv_noexcept { - data_ = 
other.data_; - size_ = other.size_; - return *this; - } -#endif - - // 24.4.2.2 Iterator support: - - nssv_constexpr const_iterator begin() const nssv_noexcept { - return data_; - } - nssv_constexpr const_iterator end() const nssv_noexcept { - return data_ + size_; - } - - nssv_constexpr const_iterator cbegin() const nssv_noexcept { - return begin(); - } - nssv_constexpr const_iterator cend() const nssv_noexcept { - return end(); - } - - nssv_constexpr const_reverse_iterator rbegin() const nssv_noexcept { - return const_reverse_iterator(end()); - } - nssv_constexpr const_reverse_iterator rend() const nssv_noexcept { - return const_reverse_iterator(begin()); - } - - nssv_constexpr const_reverse_iterator crbegin() const nssv_noexcept { - return rbegin(); - } - nssv_constexpr const_reverse_iterator crend() const nssv_noexcept { - return rend(); - } - - // 24.4.2.3 Capacity: - - nssv_constexpr size_type size() const nssv_noexcept { return size_; } - nssv_constexpr size_type length() const nssv_noexcept { return size_; } - nssv_constexpr size_type max_size() const nssv_noexcept { - return (std::numeric_limits::max)(); - } - - // since C++20 - nssv_nodiscard nssv_constexpr bool empty() const nssv_noexcept { - return 0 == size_; - } - - // 24.4.2.4 Element access: - - nssv_constexpr const_reference operator[](size_type pos) const { - return data_at(pos); - } - - nssv_constexpr14 const_reference at(size_type pos) const { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos < size()); -#else - if (pos >= size()) { - throw std::out_of_range("nonstd::string_view::at()"); - } -#endif - return data_at(pos); - } - - nssv_constexpr const_reference front() const { return data_at(0); } - nssv_constexpr const_reference back() const { - return data_at(size() - 1); - } - - nssv_constexpr const_pointer data() const nssv_noexcept { - return data_; - } - - // 24.4.2.5 Modifiers: - - nssv_constexpr14 void remove_prefix(size_type n) { - assert(n <= size()); - data_ += n; - size_ -= n; - } - - 
nssv_constexpr14 void remove_suffix(size_type n) { - assert(n <= size()); - size_ -= n; - } - - nssv_constexpr14 void swap(basic_string_view &other) nssv_noexcept { - const basic_string_view tmp(other); - other = *this; - *this = tmp; - } - - // 24.4.2.6 String operations: - - size_type copy(CharT *dest, size_type n, size_type pos = 0) const { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) { - throw std::out_of_range("nonstd::string_view::copy()"); - } -#endif - const size_type rlen = (std::min)(n, size() - pos); - - (void)Traits::copy(dest, data() + pos, rlen); - - return rlen; - } - - nssv_constexpr14 basic_string_view substr(size_type pos = 0, - size_type n = npos) const { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) { - throw std::out_of_range("nonstd::string_view::substr()"); - } -#endif - return basic_string_view(data() + pos, (std::min)(n, size() - pos)); - } - - // compare(), 6x: - - nssv_constexpr14 int compare(basic_string_view other) const - nssv_noexcept // (1) - { -#if nssv_CPP17_OR_GREATER - if (const int result = Traits::compare( - data(), other.data(), (std::min)(size(), other.size()))) -#else - if (const int result = detail::compare( - data(), other.data(), (std::min)(size(), other.size()))) -#endif - { - return result; - } - - return size() == other.size() ? 0 : size() < other.size() ? 
-1 : 1; - } - - nssv_constexpr int compare(size_type pos1, - size_type n1, - basic_string_view other) const // (2) - { - return substr(pos1, n1).compare(other); - } - - nssv_constexpr int compare(size_type pos1, - size_type n1, - basic_string_view other, - size_type pos2, - size_type n2) const // (3) - { - return substr(pos1, n1).compare(other.substr(pos2, n2)); - } - - nssv_constexpr int compare(CharT const *s) const // (4) - { - return compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, - size_type n1, - CharT const *s) const // (5) - { - return substr(pos1, n1).compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, - size_type n1, - CharT const *s, - size_type n2) const // (6) - { - return substr(pos1, n1).compare(basic_string_view(s, n2)); - } - - // 24.4.2.7 Searching: - - // starts_with(), 3x, since C++20: - - nssv_constexpr bool starts_with(basic_string_view v) const - nssv_noexcept // (1) - { - return size() >= v.size() && compare(0, v.size(), v) == 0; - } - - nssv_constexpr bool starts_with(CharT c) const nssv_noexcept // (2) - { - return starts_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool starts_with(CharT const *s) const // (3) - { - return starts_with(basic_string_view(s)); - } - - // ends_with(), 3x, since C++20: - - nssv_constexpr bool ends_with(basic_string_view v) const - nssv_noexcept // (1) - { - return size() >= v.size() && - compare(size() - v.size(), npos, v) == 0; - } - - nssv_constexpr bool ends_with(CharT c) const nssv_noexcept // (2) - { - return ends_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool ends_with(CharT const *s) const // (3) - { - return ends_with(basic_string_view(s)); - } - - // find(), 4x: - - nssv_constexpr14 size_type - find(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return assert(v.size() == 0 || v.data() != nssv_nullptr), - pos >= size() ? 
npos : to_pos(std::search(cbegin() + pos, - cend(), - v.cbegin(), - v.cend(), - Traits::eq)); - } - - nssv_constexpr14 size_type - find(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find(basic_string_view(&c, 1), pos); - } - - nssv_constexpr14 size_type find(CharT const *s, - size_type pos, - size_type n) const // (3) - { - return find(basic_string_view(s, n), pos); - } - - nssv_constexpr14 size_type find(CharT const *s, - size_type pos = 0) const // (4) - { - return find(basic_string_view(s), pos); - } - - // rfind(), 4x: - - nssv_constexpr14 size_type - rfind(basic_string_view v, - size_type pos = npos) const nssv_noexcept // (1) - { - if (size() < v.size()) { - return npos; - } - - if (v.empty()) { - return (std::min)(size(), pos); - } - - const_iterator last = - cbegin() + (std::min)(size() - v.size(), pos) + v.size(); - const_iterator result = - std::find_end(cbegin(), last, v.cbegin(), v.cend(), Traits::eq); - - return result != last ? size_type(result - cbegin()) : npos; - } - - nssv_constexpr14 size_type - rfind(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return rfind(basic_string_view(&c, 1), pos); - } - - nssv_constexpr14 size_type rfind(CharT const *s, - size_type pos, - size_type n) const // (3) - { - return rfind(basic_string_view(s, n), pos); - } - - nssv_constexpr14 size_type rfind(CharT const *s, - size_type pos = npos) const // (4) - { - return rfind(basic_string_view(s), pos); - } - - // find_first_of(), 4x: - - nssv_constexpr size_type find_first_of( - basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() ? 
npos - : to_pos(std::find_first_of(cbegin() + pos, - cend(), - v.cbegin(), - v.cend(), - Traits::eq)); - } - - nssv_constexpr size_type - find_first_of(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_of(CharT const *s, - size_type pos, - size_type n) const // (3) - { - return find_first_of(basic_string_view(s, n), pos); - } - - nssv_constexpr size_type find_first_of(CharT const *s, - size_type pos = 0) const // (4) - { - return find_first_of(basic_string_view(s), pos); - } - - // find_last_of(), 4x: - - nssv_constexpr size_type - find_last_of(basic_string_view v, - size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_of(v, size() - 1) - : to_pos(std::find_first_of( - const_reverse_iterator(cbegin() + pos + 1), - crend(), - v.cbegin(), - v.cend(), - Traits::eq)); - } - - nssv_constexpr size_type - find_last_of(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_of(CharT const *s, - size_type pos, - size_type count) const // (3) - { - return find_last_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type - find_last_of(CharT const *s, size_type pos = npos) const // (4) - { - return find_last_of(basic_string_view(s), pos); - } - - // find_first_not_of(), 4x: - - nssv_constexpr size_type find_first_not_of( - basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() ? 
npos - : to_pos(std::find_if( - cbegin() + pos, cend(), not_in_view(v))); - } - - nssv_constexpr size_type find_first_not_of( - CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_not_of( - CharT const *s, size_type pos, size_type count) const // (3) - { - return find_first_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type - find_first_not_of(CharT const *s, size_type pos = 0) const // (4) - { - return find_first_not_of(basic_string_view(s), pos); - } - - // find_last_not_of(), 4x: - - nssv_constexpr size_type - find_last_not_of(basic_string_view v, - size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_not_of(v, size() - 1) - : to_pos(std::find_if( - const_reverse_iterator(cbegin() + pos + 1), - crend(), - not_in_view(v))); - } - - nssv_constexpr size_type find_last_not_of( - CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_not_of(CharT const *s, - size_type pos, - size_type count) const // (3) - { - return find_last_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type - find_last_not_of(CharT const *s, size_type pos = npos) const // (4) - { - return find_last_not_of(basic_string_view(s), pos); - } - -// Constants: - -#if nssv_CPP17_OR_GREATER - static nssv_constexpr size_type npos = size_type(-1); -#elif nssv_CPP11_OR_GREATER - enum : size_type { npos = size_type(-1) }; -#else - enum { npos = size_type(-1) }; -#endif - - private: - struct not_in_view { - const basic_string_view v; - - nssv_constexpr explicit not_in_view(basic_string_view v_) : v(v_) {} - - nssv_constexpr bool operator()(CharT c) const { - return npos == v.find_first_of(c); - } - }; - - nssv_constexpr size_type to_pos(const_iterator it) const { - return it == cend() ? 
npos : size_type(it - cbegin()); - } - - nssv_constexpr size_type to_pos(const_reverse_iterator it) const { - return it == crend() ? npos : size_type(crend() - it - 1); - } - - nssv_constexpr const_reference data_at(size_type pos) const { -#if nssv_BETWEEN(nssv_COMPILER_GNUC_VERSION, 1, 500) - return data_[pos]; -#else - return assert(pos < size()), data_[pos]; -#endif - } - - private: - const_pointer data_; - size_type size_; - - public: -#if nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - - template - basic_string_view(std::basic_string const &s) - nssv_noexcept : data_(s.data()), - size_(s.size()) {} - -#if nssv_HAVE_EXPLICIT_CONVERSION - - template - explicit operator std::basic_string() const { - return to_string(Allocator()); - } - -#endif // nssv_HAVE_EXPLICIT_CONVERSION - -#if nssv_CPP11_OR_GREATER - - template > - std::basic_string to_string( - Allocator const &a = Allocator()) const { - return std::basic_string( - begin(), end(), a); - } - -#else - - std::basic_string to_string() const { - return std::basic_string(begin(), end()); - } - - template - std::basic_string to_string( - Allocator const &a) const { - return std::basic_string( - begin(), end(), a); - } - -#endif // nssv_CPP11_OR_GREATER - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - }; - - // - // Non-member functions: - // - - // 24.4.3 Non-member comparison functions: - // lexicographically compare two string views (function template): - - template - nssv_constexpr bool operator==(basic_string_view lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template - nssv_constexpr bool operator!=(basic_string_view lhs, - basic_string_view rhs) - nssv_noexcept { - return !(lhs == rhs); - } - - template - nssv_constexpr bool operator<(basic_string_view lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) < 0; - } - - template - nssv_constexpr bool operator<=(basic_string_view lhs, - basic_string_view 
rhs) - nssv_noexcept { - return lhs.compare(rhs) <= 0; - } - - template - nssv_constexpr bool operator>(basic_string_view lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) > 0; - } - - template - nssv_constexpr bool operator>=(basic_string_view lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) >= 0; - } - -// Let S be basic_string_view, and sv be an instance of S. -// Implementations shall provide sufficient additional overloads marked -// constexpr and noexcept so that an object t with an implicit conversion -// to S can be compared according to Table 67. - -#if !nssv_CPP11_OR_GREATER || nssv_BETWEEN(nssv_COMPILER_MSVC_VERSION, 100, 141) - - // accommodate for older compilers: - - // == - - template - nssv_constexpr bool operator==(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return lhs.size() == detail::length(rhs) && lhs.compare(rhs) == 0; - } - - template - nssv_constexpr bool operator==( - CharT const *lhs, basic_string_view rhs) nssv_noexcept { - return detail::length(lhs) == rhs.size() && rhs.compare(lhs) == 0; - } - - template - nssv_constexpr bool operator==(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template - nssv_constexpr bool operator==(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template - nssv_constexpr bool operator!=(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return !(lhs == rhs); - } - - template - nssv_constexpr bool operator!=( - CharT const *lhs, basic_string_view rhs) nssv_noexcept { - return !(lhs == rhs); - } - - template - nssv_constexpr bool operator!=(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return !(lhs == rhs); - } - - template - nssv_constexpr bool operator!=(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return !(lhs 
== rhs); - } - - // < - - template - nssv_constexpr bool operator<(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return lhs.compare(rhs) < 0; - } - - template - nssv_constexpr bool operator<( - CharT const *lhs, basic_string_view rhs) nssv_noexcept { - return rhs.compare(lhs) > 0; - } - - template - nssv_constexpr bool operator<(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return lhs.compare(rhs) < 0; - } - - template - nssv_constexpr bool operator<(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return rhs.compare(lhs) > 0; - } - - // <= - - template - nssv_constexpr bool operator<=(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return lhs.compare(rhs) <= 0; - } - - template - nssv_constexpr bool operator<=( - CharT const *lhs, basic_string_view rhs) nssv_noexcept { - return rhs.compare(lhs) >= 0; - } - - template - nssv_constexpr bool operator<=(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return lhs.compare(rhs) <= 0; - } - - template - nssv_constexpr bool operator<=(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return rhs.compare(lhs) >= 0; - } - - // > - - template - nssv_constexpr bool operator>(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return lhs.compare(rhs) > 0; - } - - template - nssv_constexpr bool operator>( - CharT const *lhs, basic_string_view rhs) nssv_noexcept { - return rhs.compare(lhs) < 0; - } - - template - nssv_constexpr bool operator>(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return lhs.compare(rhs) > 0; - } - - template - nssv_constexpr bool operator>(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return rhs.compare(lhs) < 0; - } - - // >= - - template - nssv_constexpr bool operator>=(basic_string_view lhs, - CharT const *rhs) nssv_noexcept { - return lhs.compare(rhs) >= 0; - } - - template - nssv_constexpr bool operator>=( - CharT const *lhs, basic_string_view 
rhs) nssv_noexcept { - return rhs.compare(lhs) <= 0; - } - - template - nssv_constexpr bool operator>=(basic_string_view lhs, - std::basic_string rhs) - nssv_noexcept { - return lhs.compare(rhs) >= 0; - } - - template - nssv_constexpr bool operator>=(std::basic_string rhs, - basic_string_view lhs) - nssv_noexcept { - return rhs.compare(lhs) <= 0; - } - -#else // newer compilers: - -#define nssv_BASIC_STRING_VIEW_I(T, U) \ - typename std::decay>::type - -#if defined(_MSC_VER) // issue 40 -#define nssv_MSVC_ORDER(x) , int = x -#else -#define nssv_MSVC_ORDER(x) /*, int=x*/ -#endif - - // == - - template - nssv_constexpr bool operator==(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template - nssv_constexpr bool operator==(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template - nssv_constexpr bool operator!=(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return !(lhs == rhs); - } - - template - nssv_constexpr bool operator!=(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return !(lhs == rhs); - } - - // < - - template - nssv_constexpr bool operator<(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return lhs.compare(rhs) < 0; - } - - template - nssv_constexpr bool operator<(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) < 0; - } - - // <= - - template - nssv_constexpr bool operator<=(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return lhs.compare(rhs) <= 0; - } - - template - nssv_constexpr bool operator<=(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return 
lhs.compare(rhs) <= 0; - } - - // > - - template - nssv_constexpr bool operator>(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return lhs.compare(rhs) > 0; - } - - template - nssv_constexpr bool operator>(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) > 0; - } - - // >= - - template - nssv_constexpr bool operator>=(basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) - rhs) nssv_noexcept { - return lhs.compare(rhs) >= 0; - } - - template - nssv_constexpr bool operator>=(nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) - nssv_noexcept { - return lhs.compare(rhs) >= 0; - } - -#undef nssv_MSVC_ORDER -#undef nssv_BASIC_STRING_VIEW_I - -#endif // compiler-dependent approach to comparisons - -// 24.4.4 Inserters and extractors: - -#if !nssv_CONFIG_NO_STREAM_INSERTION - - namespace detail { - - template - void write_padding(Stream &os, std::streamsize n) { - for (std::streamsize i = 0; i < n; ++i) os.rdbuf()->sputc(os.fill()); - } - - template - Stream &write_to_stream(Stream &os, View const &sv) { - typename Stream::sentry sentry(os); - - if (!os) return os; - - const std::streamsize length = - static_cast(sv.length()); - - // Whether, and how, to pad: - const bool pad = (length < os.width()); - const bool left_pad = - pad && - (os.flags() & std::ios_base::adjustfield) == std::ios_base::right; - - if (left_pad) write_padding(os, os.width() - length); - - // Write span characters: - os.rdbuf()->sputn(sv.begin(), length); - - if (pad && !left_pad) write_padding(os, os.width() - length); - - // Reset output stream width: - os.width(0); - - return os; - } - - } // namespace detail - - template - std::basic_ostream &operator<<( - std::basic_ostream &os, - basic_string_view sv) { - return detail::write_to_stream(os, sv); - } - -#endif // nssv_CONFIG_NO_STREAM_INSERTION - - // Several typedefs for common character types are provided: - 
- typedef basic_string_view string_view; - typedef basic_string_view wstring_view; -#if nssv_HAVE_WCHAR16_T - typedef basic_string_view u16string_view; - typedef basic_string_view u32string_view; -#endif - } -} // namespace nonstd::sv_lite - -// -// 24.4.6 Suffix for basic_string_view literals: -// - -#if nssv_HAVE_USER_DEFINED_LITERALS - -namespace nonstd { -nssv_inline_ns namespace literals { - nssv_inline_ns namespace string_view_literals { - -#if nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - - nssv_constexpr nonstd::sv_lite::string_view operator"" sv( - const char *str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator"" sv( - const char16_t *str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator"" sv( - const char32_t *str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator"" sv( - const wchar_t *str, size_t len) nssv_noexcept // (4) - { - return nonstd::sv_lite::wstring_view{str, len}; - } - -#endif // nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - -#if nssv_CONFIG_USR_SV_OPERATOR - - nssv_constexpr nonstd::sv_lite::string_view operator"" _sv( - const char *str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator"" _sv( - const char16_t *str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator"" _sv( - const char32_t *str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{str, len}; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator"" _sv( - const wchar_t *str, size_t len) nssv_noexcept // (4) 
- { - return nonstd::sv_lite::wstring_view{str, len}; - } - -#endif // nssv_CONFIG_USR_SV_OPERATOR - } -} -} // namespace nonstd::literals::string_view_literals - -#endif - -// -// Extensions for std::string: -// - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { -namespace sv_lite { - -// Exclude MSVC 14 (19.00): it yields ambiguous to_string(): - -#if nssv_CPP11_OR_GREATER && nssv_COMPILER_MSVC_VERSION != 140 - -template > -std::basic_string to_string( - basic_string_view v, Allocator const &a = Allocator()) { - return std::basic_string(v.begin(), v.end(), a); -} - -#else - -template -std::basic_string to_string(basic_string_view v) { - return std::basic_string(v.begin(), v.end()); -} - -template -std::basic_string to_string( - basic_string_view v, Allocator const &a) { - return std::basic_string(v.begin(), v.end(), a); -} - -#endif // nssv_CPP11_OR_GREATER - -template -basic_string_view to_string_view( - std::basic_string const &s) { - return basic_string_view(s.data(), s.size()); -} -} -} // namespace nonstd::sv_lite - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -// -// make types and algorithms available in namespace nonstd: -// - -namespace nonstd { - -using sv_lite::basic_string_view; -using sv_lite::string_view; -using sv_lite::wstring_view; - -#if nssv_HAVE_WCHAR16_T -using sv_lite::u16string_view; -#endif -#if nssv_HAVE_WCHAR32_T -using sv_lite::u32string_view; -#endif - -// literal "sv" - -using sv_lite::operator==; -using sv_lite::operator!=; -using sv_lite::operator<; -using sv_lite::operator<=; -using sv_lite::operator>; -using sv_lite::operator>=; - -#if !nssv_CONFIG_NO_STREAM_INSERTION -using sv_lite::operator<<; -#endif - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -using sv_lite::to_string; -using sv_lite::to_string_view; -#endif - -} // namespace nonstd - -// 24.4.5 Hash support (C++11): - -// Note: The hash value of a string view object is equal to the hash value of -// the corresponding 
string object. - -#if nssv_HAVE_STD_HASH - -#include - -namespace std { - -template <> -struct hash { - public: - std::size_t operator()(nonstd::string_view v) const nssv_noexcept { - return std::hash()(std::string(v.data(), v.size())); - } -}; - -template <> -struct hash { - public: - std::size_t operator()(nonstd::wstring_view v) const nssv_noexcept { - return std::hash()(std::wstring(v.data(), v.size())); - } -}; - -template <> -struct hash { - public: - std::size_t operator()(nonstd::u16string_view v) const nssv_noexcept { - return std::hash()(std::u16string(v.data(), v.size())); - } -}; - -template <> -struct hash { - public: - std::size_t operator()(nonstd::u32string_view v) const nssv_noexcept { - return std::hash()(std::u32string(v.data(), v.size())); - } -}; - -} // namespace std - -#endif // nssv_HAVE_STD_HASH - -nssv_RESTORE_WARNINGS() - -#endif // nssv_HAVE_STD_STRING_VIEW -#endif // NONSTD_SV_LITE_H_INCLUDED -/* end file include/simdjson/nonstd/string_view.hpp */ -SIMDJSON_POP_DISABLE_WARNINGS - -namespace std { -using string_view = nonstd::string_view; -} -#endif // SIMDJSON_HAS_STRING_VIEW -#undef SIMDJSON_HAS_STRING_VIEW // We are not going to need this macro anymore. - -/// If EXPR is an error, returns it. -#define SIMDJSON_TRY(EXPR) \ - { \ - auto _err = (EXPR); \ - if (_err) { \ - return _err; \ - } \ - } - -#ifndef SIMDJSON_DEVELOPMENT_CHECKS -#ifndef NDEBUG -#define SIMDJSON_DEVELOPMENT_CHECKS -#endif -#endif - -// The SIMDJSON_CHECK_EOF macro is a feature flag for the "don't require -// padding" -// feature. 
- -#if SIMDJSON_CPLUSPLUS17 -// if we have C++, then fallthrough is a default attribute -#define simdjson_fallthrough [[fallthrough]] -// check if we have __attribute__ support -#elif defined(__has_attribute) -// check if we have the __fallthrough__ attribute -#if __has_attribute(__fallthrough__) -// we are good to go: -#define simdjson_fallthrough __attribute__((__fallthrough__)) -#endif // __has_attribute(__fallthrough__) -#endif // SIMDJSON_CPLUSPLUS17 -// on some systems, we simply do not have support for fallthrough, so use a -// default: -#ifndef simdjson_fallthrough -#define simdjson_fallthrough \ - do { \ - } while (0) /* fallthrough */ -#endif // simdjson_fallthrough - -#endif // SIMDJSON_COMMON_DEFS_H -/* end file include/simdjson/common_defs.h */ - -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_UNDESIRED_WARNINGS - -// Public API -/* begin file include/simdjson/error.h */ -#ifndef SIMDJSON_ERROR_H -#define SIMDJSON_ERROR_H - -#include - -namespace simdjson { - -/** - * All possible errors returned by simdjson. These error codes are subject to - * change - * and not all simdjson kernel returns the same error code given the same input: - * it is not - * well defined which error a given input should produce. - * - * Only SUCCESS evaluates to false as a Boolean. All other error codes will - * evaluate - * to true as a Boolean. 
- */ -enum error_code { - SUCCESS = 0, ///< No error - CAPACITY, ///< This parser can't support a document that big - MEMALLOC, ///< Error allocating memory, most likely out of memory - TAPE_ERROR, ///< Something went wrong while writing to the tape (stage 2), - /// this is a generic error - DEPTH_ERROR, ///< Your document exceeds the user-specified depth limitation - STRING_ERROR, ///< Problem while parsing a string - T_ATOM_ERROR, ///< Problem while parsing an atom starting with the letter - ///'t' - F_ATOM_ERROR, ///< Problem while parsing an atom starting with the letter - ///'f' - N_ATOM_ERROR, ///< Problem while parsing an atom starting with the letter - ///'n' - NUMBER_ERROR, ///< Problem while parsing a number - UTF8_ERROR, ///< the input is not valid UTF-8 - UNINITIALIZED, ///< unknown error, or uninitialized document - EMPTY, ///< no structural element found - UNESCAPED_CHARS, ///< found unescaped characters in a string. - UNCLOSED_STRING, ///< missing quote at the end - UNSUPPORTED_ARCHITECTURE, ///< unsupported architecture - INCORRECT_TYPE, ///< JSON element has a different type than user expected - NUMBER_OUT_OF_RANGE, ///< JSON number does not fit in 64 bits - INDEX_OUT_OF_BOUNDS, ///< JSON array index too large - NO_SUCH_FIELD, ///< JSON field not found in object - IO_ERROR, ///< Error reading a file - INVALID_JSON_POINTER, ///< Invalid JSON pointer reference - INVALID_URI_FRAGMENT, ///< Invalid URI fragment - UNEXPECTED_ERROR, ///< indicative of a bug in simdjson - PARSER_IN_USE, ///< parser is already in use. - OUT_OF_ORDER_ITERATION, ///< tried to iterate an array or object out of - /// order - INSUFFICIENT_PADDING, ///< The JSON doesn't have enough padding for - /// simdjson to safely parse it. - INCOMPLETE_ARRAY_OR_OBJECT, ///< The document ends early. - SCALAR_DOCUMENT_AS_VALUE, ///< A scalar document is treated as a value. - OUT_OF_BOUNDS, ///< Attempted to access location outside of document. 
- NUM_ERROR_CODES -}; - -/** - * Get the error message for the given error code. - * - * dom::parser parser; - * dom::element doc; - * auto error = parser.parse("foo",3).get(doc); - * if (error) { printf("Error: %s\n", error_message(error)); } - * - * @return The error message. - */ -inline const char *error_message(error_code error) noexcept; - -/** - * Write the error message to the output stream - */ -inline std::ostream &operator<<(std::ostream &out, error_code error) noexcept; - -/** - * Exception thrown when an exception-supporting simdjson method is called - */ -struct simdjson_error : public std::exception { - /** - * Create an exception from a simdjson error code. - * @param error The error code - */ - simdjson_error(error_code error) noexcept : _error{error} {} - /** The error message */ - const char *what() const noexcept { return error_message(error()); } - /** The error code */ - error_code error() const noexcept { return _error; } - - private: - /** The error code that was used */ - error_code _error; -}; - -namespace internal { - -/** - * The result of a simdjson operation that could fail. - * - * Gives the option of reading error codes, or throwing an exception by casting - * to the desired result. - * - * This is a base class for implementations that want to add functions to the - * result type for - * chaining. - * - * Override like: - * - * struct simdjson_result : public internal::simdjson_result_base { - * simdjson_result() noexcept : internal::simdjson_result_base() {} - * simdjson_result(error_code error) noexcept : - * internal::simdjson_result_base(error) {} - * simdjson_result(T &&value) noexcept : - * internal::simdjson_result_base(std::forward(value)) {} - * simdjson_result(T &&value, error_code error) noexcept : - * internal::simdjson_result_base(value, error) {} - * // Your extra methods here - * } - * - * Then any method returning simdjson_result will be chainable with your - * methods. 
- */ -template -struct simdjson_result_base : protected std::pair { - /** - * Create a new empty result with error = UNINITIALIZED. - */ - simdjson_really_inline simdjson_result_base() noexcept; - - /** - * Create a new error result. - */ - simdjson_really_inline simdjson_result_base(error_code error) noexcept; - - /** - * Create a new successful result. - */ - simdjson_really_inline simdjson_result_base(T &&value) noexcept; - - /** - * Create a new result with both things (use if you don't want to branch - * when creating the result). - */ - simdjson_really_inline simdjson_result_base(T &&value, - error_code error) noexcept; - - /** - * Move the value and the error to the provided variables. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - * @param error The variable to assign the error to. Set to SUCCESS if there - * is no error. - */ - simdjson_really_inline void tie(T &value, error_code &error) && noexcept; - - /** - * Move the value to the provided variable. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - */ - simdjson_really_inline error_code get(T &value) && noexcept; - - /** - * Move the value to the provided variable. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - */ - simdjson_really_inline const T &value(error_code &error) const &noexcept; - - /** - * The error. - */ - simdjson_really_inline error_code error() const noexcept; - -#if SIMDJSON_EXCEPTIONS - - /** - * Get the result value. - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &value() & noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &&value() && noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. 
- */ - simdjson_really_inline T &&take_value() && noexcept(false); - - /** - * Cast to the value (will throw on error). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline operator T &&() && noexcept(false); -#endif // SIMDJSON_EXCEPTIONS - - /** - * Get the result value. This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline const T &value_unsafe() const &noexcept; - - /** - * Take the result value (move it). This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline T &&value_unsafe() && noexcept; - -}; // struct simdjson_result_base - -} // namespace internal - -/** - * The result of a simdjson operation that could fail. - * - * Gives the option of reading error codes, or throwing an exception by casting - * to the desired result. - */ -template -struct simdjson_result : public internal::simdjson_result_base { - /** - * @private Create a new empty result with error = UNINITIALIZED. - */ - simdjson_really_inline simdjson_result() noexcept; - /** - * @private Create a new error result. - */ - simdjson_really_inline simdjson_result(T &&value) noexcept; - /** - * @private Create a new successful result. - */ - simdjson_really_inline simdjson_result(error_code error_code) noexcept; - /** - * @private Create a new result with both things (use if you don't want to - * branch when creating the result). - */ - simdjson_really_inline simdjson_result(T &&value, - error_code error) noexcept; - - /** - * Move the value and the error to the provided variables. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - * @param error The variable to assign the error to. Set to SUCCESS if there - * is no error. - */ - simdjson_really_inline void tie(T &value, error_code &error) && noexcept; - - /** - * Move the value to the provided variable. 
- * - * @param value The variable to assign the value to. May not be set if there - * is an error. - */ - simdjson_warn_unused simdjson_really_inline error_code get(T &value) && - noexcept; - - /** - * The error. - */ - simdjson_really_inline error_code error() const noexcept; - -#if SIMDJSON_EXCEPTIONS - - /** - * Get the result value. - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &value() & noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &&value() && noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &&take_value() && noexcept(false); - - /** - * Cast to the value (will throw on error). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline operator T &&() && noexcept(false); -#endif // SIMDJSON_EXCEPTIONS - - /** - * Get the result value. This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline const T &value_unsafe() const &noexcept; - - /** - * Take the result value (move it). This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline T &&value_unsafe() && noexcept; - -}; // struct simdjson_result - -#if SIMDJSON_EXCEPTIONS - -template -inline std::ostream &operator<<(std::ostream &out, simdjson_result value) { - return out << value.value(); -} -#endif // SIMDJSON_EXCEPTIONS - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -/** - * @deprecated This is an alias and will be removed, use error_code instead - */ -using ErrorValues[[deprecated( - "This is an alias and will be removed, use error_code instead")]] = - error_code; - -/** - * @deprecated Error codes should be stored and returned as `error_code`, use - * `error_message()` instead. 
- */ -[ - [deprecated("Error codes should be stored and returned as `error_code`, " - "use `error_message()` instead.")]] inline const std::string -error_message(int error) noexcept; -#endif // SIMDJSON_DISABLE_DEPRECATED_API -} // namespace simdjson - -#endif // SIMDJSON_ERROR_H -/* end file include/simdjson/error.h */ -/* begin file include/simdjson/minify.h */ -#ifndef SIMDJSON_MINIFY_H -#define SIMDJSON_MINIFY_H - -/* begin file include/simdjson/padded_string.h */ -#ifndef SIMDJSON_PADDED_STRING_H -#define SIMDJSON_PADDED_STRING_H - -#include -#include -#include -#include - -namespace simdjson { - -class padded_string_view; - -/** - * String with extra allocation for ease of use with parser::parse() - * - * This is a move-only class, it cannot be copied. - */ -struct padded_string final { - /** - * Create a new, empty padded string. - */ - explicit inline padded_string() noexcept; - /** - * Create a new padded string buffer. - * - * @param length the size of the string. - */ - explicit inline padded_string(size_t length) noexcept; - /** - * Create a new padded string by copying the given input. - * - * @param data the buffer to copy - * @param length the number of bytes to copy - */ - explicit inline padded_string(const char *data, size_t length) noexcept; - /** - * Create a new padded string by copying the given input. - * - * @param str_ the string to copy - */ - inline padded_string(const std::string &str_) noexcept; - /** - * Create a new padded string by copying the given input. - * - * @param sv_ the string to copy - */ - inline padded_string(std::string_view sv_) noexcept; - /** - * Move one padded string into another. - * - * The original padded string will be reduced to zero capacity. - * - * @param o the string to move. - */ - inline padded_string(padded_string &&o) noexcept; - /** - * Move one padded string into another. - * - * The original padded string will be reduced to zero capacity. - * - * @param o the string to move. 
- */ - inline padded_string &operator=(padded_string &&o) noexcept; - inline void swap(padded_string &o) noexcept; - ~padded_string() noexcept; - - /** - * The length of the string. - * - * Does not include padding. - */ - size_t size() const noexcept; - - /** - * The length of the string. - * - * Does not include padding. - */ - size_t length() const noexcept; - - /** - * The string data. - **/ - const char *data() const noexcept; - const uint8_t *u8data() const noexcept { - return static_cast( - static_cast(data_ptr)); - } - - /** - * The string data. - **/ - char *data() noexcept; - - /** - * Create a std::string_view with the same content. - */ - operator std::string_view() const; - - /** - * Create a padded_string_view with the same content. - */ - operator padded_string_view() const noexcept; - - /** - * Load this padded string from a file. - * - * @return IO_ERROR on error. Be mindful that on some 32-bit systems, - * the file size might be limited to 2 GB. - * - * @param path the path to the file. - **/ - inline static simdjson_result load( - std::string_view path) noexcept; - - private: - padded_string &operator=(const padded_string &o) = delete; - padded_string(const padded_string &o) = delete; - - size_t viable_size{0}; - char *data_ptr{nullptr}; - -}; // padded_string - -/** - * Send padded_string instance to an output stream. - * - * @param out The output stream. - * @param s The padded_string instance. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. - */ -inline std::ostream &operator<<(std::ostream &out, const padded_string &s) { - return out << s.data(); -} - -#if SIMDJSON_EXCEPTIONS -/** - * Send padded_string instance to an output stream. - * - * @param out The output stream. - * @param s The padded_string instance. - * @throw simdjson_error if the result being printed has an error. 
If there is - * an error with the - * underlying output stream, that error will be propagated - * (simdjson_error will not be - * thrown). - */ -inline std::ostream &operator<<( - std::ostream &out, simdjson_result &s) noexcept(false) { - return out << s.value(); -} -#endif - -} // namespace simdjson - -// This is deliberately outside of simdjson so that people get it without having -// to use the namespace -inline simdjson::padded_string operator"" _padded(const char *str, size_t len) { - return simdjson::padded_string(str, len); -} - -namespace simdjson { -namespace internal { - -// The allocate_padded_buffer function is a low-level function to allocate -// memory -// with padding so we can read past the "length" bytes safely. It is used by -// the padded_string class automatically. It returns nullptr in case -// of error: the caller should check for a null pointer. -// The length parameter is the maximum size in bytes of the string. -// The caller is responsible to free the memory (e.g., delete[] (...)). -inline char *allocate_padded_buffer(size_t length) noexcept; - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_PADDED_STRING_H -/* end file include/simdjson/padded_string.h */ -#include -#include -#include - -namespace simdjson { - - -/** - * - * Minify the input string assuming that it represents a JSON string, does not - * parse or validate. - * This function is much faster than parsing a JSON string and then writing a - * minified version of it. - * However, it does not validate the input. It will merely return an error in - * simple cases (e.g., if - * there is a string that was never terminated). - * - * - * @param buf the json document to minify. - * @param len the length of the json document. - * @param dst the buffer to write the minified document to. *MUST* be allocated - * up to len bytes. - * @param dst_len the number of bytes written. Output only. - * @return the error code, or SUCCESS if there was no error. 
- */ -simdjson_warn_unused error_code minify(const char *buf, - size_t len, - char *dst, - size_t &dst_len) noexcept; - -} // namespace simdjson - -#endif // SIMDJSON_MINIFY_H -/* end file include/simdjson/minify.h */ -/* begin file include/simdjson/padded_string_view.h */ -#ifndef SIMDJSON_PADDED_STRING_VIEW_H -#define SIMDJSON_PADDED_STRING_VIEW_H - - -#include -#include -#include -#include - -namespace simdjson { - -/** - * User-provided string that promises it has extra padded bytes at the end for - * use with parser::parse(). - */ -class padded_string_view : public std::string_view { - private: - size_t _capacity; - - public: - /** Create an empty padded_string_view. */ - inline padded_string_view() noexcept = default; - - /** - * Promise the given buffer has at least SIMDJSON_PADDING extra bytes - * allocated to it. - * - * @param s The string. - * @param len The length of the string (not including padding). - * @param capacity The allocated length of the string, including padding. - */ - explicit inline padded_string_view(const char *s, - size_t len, - size_t capacity) noexcept; - /** overload explicit inline padded_string_view(const char* s, size_t len) - * noexcept */ - explicit inline padded_string_view(const uint8_t *s, - size_t len, - size_t capacity) noexcept; - - /** - * Promise the given string has at least SIMDJSON_PADDING extra bytes - * allocated to it. - * - * The capacity of the string will be used to determine its padding. - * - * @param s The string. - */ - explicit inline padded_string_view(const std::string &s) noexcept; - - /** - * Promise the given string_view has at least SIMDJSON_PADDING extra bytes - * allocated to it. - * - * @param s The string. - * @param capacity The allocated length of the string, including padding. - */ - explicit inline padded_string_view(std::string_view s, - size_t capacity) noexcept; - - /** The number of allocated bytes. 
*/ - inline size_t capacity() const noexcept; - - /** The amount of padding on the string (capacity() - length()) */ - inline size_t padding() const noexcept; - -}; // padded_string_view - -#if SIMDJSON_EXCEPTIONS -/** - * Send padded_string instance to an output stream. - * - * @param out The output stream. - * @param s The padded_string_view. - * @throw simdjson_error if the result being printed has an error. If there is - * an error with the - * underlying output stream, that error will be propagated - * (simdjson_error will not be - * thrown). - */ -inline std::ostream &operator<<( - std::ostream &out, simdjson_result &s) noexcept(false) { - return out << s.value(); -} -#endif - -} // namespace simdjson - -#endif // SIMDJSON_PADDED_STRING_VIEW_H -/* end file include/simdjson/padded_string_view.h */ -/* begin file include/simdjson/implementation.h */ -#ifndef SIMDJSON_IMPLEMENTATION_H -#define SIMDJSON_IMPLEMENTATION_H - -/* begin file include/simdjson/internal/dom_parser_implementation.h */ -#ifndef SIMDJSON_INTERNAL_DOM_PARSER_IMPLEMENTATION_H -#define SIMDJSON_INTERNAL_DOM_PARSER_IMPLEMENTATION_H - -#include - -namespace simdjson { - -namespace dom { -class document; -} // namespace dom - -/** -* This enum is used with the dom_parser_implementation::stage1 function. -* 1) The regular mode expects a fully formed JSON document. -* 2) The streaming_partial mode expects a possibly truncated -* input within a stream on JSON documents. -* 3) The stream_final mode allows us to truncate final -* unterminated strings. It is useful in conjunction with streaming_partial. -*/ -enum class stage1_mode { regular, streaming_partial, streaming_final }; - -/** - * Returns true if mode == streaming_partial or mode == streaming_final - */ -inline bool is_streaming(stage1_mode mode) { - // performance note: it is probably faster to check that mode is different - // from regular than checking that it is either streaming_partial or - // streaming_final. 
- return (mode != stage1_mode::regular); - // return (mode == stage1_mode::streaming_partial || mode == - // stage1_mode::streaming_final); -} - - -namespace internal { - - -/** - * An implementation of simdjson's DOM parser for a particular CPU architecture. - * - * This class is expected to be accessed only by pointer, and never move in - * memory (though the - * pointer can move). - */ -class dom_parser_implementation { - public: - /** - * @private For internal implementation use - * - * Run a full JSON parse on a single document (stage1 + stage2). - * - * Guaranteed only to be called when capacity > document length. - * - * Overridden by each implementation. - * - * @param buf The json document to parse. *MUST* be allocated up to len + - * SIMDJSON_PADDING bytes. - * @param len The length of the json document. - * @return The error code, or SUCCESS if there was no error. - */ - simdjson_warn_unused virtual error_code parse( - const uint8_t *buf, size_t len, dom::document &doc) noexcept = 0; - - /** - * @private For internal implementation use - * - * Stage 1 of the document parser. - * - * Guaranteed only to be called when capacity > document length. - * - * Overridden by each implementation. - * - * @param buf The json document to parse. - * @param len The length of the json document. - * @param streaming Whether this is being called by parser::parse_many. - * @return The error code, or SUCCESS if there was no error. - */ - simdjson_warn_unused virtual error_code stage1( - const uint8_t *buf, size_t len, stage1_mode streaming) noexcept = 0; - - /** - * @private For internal implementation use - * - * Stage 2 of the document parser. - * - * Called after stage1(). - * - * Overridden by each implementation. - * - * @param doc The document to output to. - * @return The error code, or SUCCESS if there was no error. 
- */ - simdjson_warn_unused virtual error_code stage2( - dom::document &doc) noexcept = 0; - - /** - * @private For internal implementation use - * - * Stage 2 of the document parser for parser::parse_many. - * - * Guaranteed only to be called after stage1(). - * Overridden by each implementation. - * - * @param doc The document to output to. - * @return The error code, SUCCESS if there was no error, or EMPTY if all - * documents have been parsed. - */ - simdjson_warn_unused virtual error_code stage2_next( - dom::document &doc) noexcept = 0; - - /** - * Change the capacity of this parser. - * - * The capacity can never exceed SIMDJSON_MAXSIZE_BYTES (e.g., 4 GB) - * and an CAPACITY error is returned if it is attempted. - * - * Generally used for reallocation. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. - * @return The error code, or SUCCESS if there was no error. - */ - virtual error_code set_capacity(size_t capacity) noexcept = 0; - - /** - * Change the max depth of this parser. - * - * Generally used for reallocation. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. - * @return The error code, or SUCCESS if there was no error. - */ - virtual error_code set_max_depth(size_t max_depth) noexcept = 0; - - /** - * Deallocate this parser. - */ - virtual ~dom_parser_implementation() = default; - - /** Number of structural indices passed from stage 1 to stage 2 */ - uint32_t n_structural_indexes{0}; - /** Structural indices passed from stage 1 to stage 2 */ - std::unique_ptr structural_indexes{}; - /** Next structural index to parse */ - uint32_t next_structural_index{0}; - - /** - * The largest document this parser can support without reallocating. - * - * @return Current capacity, in bytes. - */ - simdjson_really_inline size_t capacity() const noexcept; - - /** - * The maximum level of nested object and arrays supported by this parser. - * - * @return Maximum depth, in bytes. 
- */ - simdjson_really_inline size_t max_depth() const noexcept; - - /** - * Ensure this parser has enough memory to process JSON documents up to - * `capacity` bytes in length - * and `max_depth` depth. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. Defaults to DEFAULT_MAX_DEPTH. - * @return The error, if there is one. - */ - simdjson_warn_unused inline error_code allocate(size_t capacity, - size_t max_depth) noexcept; - - protected: - /** - * The maximum document length this parser supports. - * - * Buffers are large enough to handle any document up to this length. - */ - size_t _capacity{0}; - - /** - * The maximum depth (number of nested objects and arrays) supported by this - * parser. - * - * Defaults to DEFAULT_MAX_DEPTH. - */ - size_t _max_depth{0}; - - // Declaring these so that subclasses can use them to implement their - // constructors. - simdjson_really_inline dom_parser_implementation() noexcept; - simdjson_really_inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - simdjson_really_inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - - simdjson_really_inline dom_parser_implementation( - const dom_parser_implementation &) noexcept = delete; - simdjson_really_inline dom_parser_implementation &operator=( - const dom_parser_implementation &other) noexcept = delete; -}; // class dom_parser_implementation - -simdjson_really_inline -dom_parser_implementation::dom_parser_implementation() noexcept = default; -simdjson_really_inline dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -simdjson_really_inline dom_parser_implementation &dom_parser_implementation:: -operator=(dom_parser_implementation &&other) noexcept = default; - -simdjson_really_inline size_t dom_parser_implementation::capacity() const - noexcept { - return _capacity; -} - -simdjson_really_inline size_t 
dom_parser_implementation::max_depth() const - noexcept { - return _max_depth; -} - -simdjson_warn_unused inline error_code dom_parser_implementation::allocate( - size_t capacity, size_t max_depth) noexcept { - if (this->max_depth() != max_depth) { - error_code err = set_max_depth(max_depth); - if (err) { - return err; - } - } - if (_capacity != capacity) { - error_code err = set_capacity(capacity); - if (err) { - return err; - } - } - return SUCCESS; -} - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_DOM_PARSER_IMPLEMENTATION_H -/* end file include/simdjson/internal/dom_parser_implementation.h */ -/* begin file include/simdjson/internal/isadetection.h */ -/* From -https://github.com/endorno/pytorch/blob/master/torch/lib/TH/generic/simd/simd.h -Highly modified. - -Copyright (c) 2016- Facebook, Inc (Adam Paszke) -Copyright (c) 2014- Facebook, Inc (Soumith Chintala) -Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) -Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) -Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) -Copyright (c) 2011-2013 NYU (Clement Farabet) -Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, -Iain Melvin, Jason Weston) Copyright (c) 2006 Idiap Research Institute -(Samy Bengio) Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, -Samy Bengio, Johnny Mariethoz) - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories -America and IDIAP Research Institute nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. -*/ - -#ifndef SIMDJSON_INTERNAL_ISADETECTION_H -#define SIMDJSON_INTERNAL_ISADETECTION_H - -#include -#include -#if defined(_MSC_VER) -#include -#elif defined(HAVE_GCC_GET_CPUID) && defined(USE_GCC_GET_CPUID) -#include -#endif - -namespace simdjson { -namespace internal { - - -enum instruction_set { - DEFAULT = 0x0, - NEON = 0x1, - AVX2 = 0x4, - SSE42 = 0x8, - PCLMULQDQ = 0x10, - BMI1 = 0x20, - BMI2 = 0x40, - ALTIVEC = 0x80 -}; - -#if defined(__PPC64__) - -static inline uint32_t detect_supported_architectures() { - return instruction_set::ALTIVEC; -} - -#elif defined(__arm__) || defined(__aarch64__) // incl. 
armel, armhf, arm64 - -#if defined(__ARM_NEON) - -static inline uint32_t detect_supported_architectures() { - return instruction_set::NEON; -} - -#else // ARM without NEON - -static inline uint32_t detect_supported_architectures() { - return instruction_set::DEFAULT; -} - -#endif - -#elif defined(__x86_64__) || defined(_M_AMD64) // x64 - - -namespace { -// Can be found on Intel ISA Reference for CPUID -constexpr uint32_t cpuid_avx2_bit = - 1 << 5; ///< @private Bit 5 of EBX for EAX=0x7 -constexpr uint32_t cpuid_bmi1_bit = - 1 << 3; ///< @private bit 3 of EBX for EAX=0x7 -constexpr uint32_t cpuid_bmi2_bit = - 1 << 8; ///< @private bit 8 of EBX for EAX=0x7 -constexpr uint32_t cpuid_sse42_bit = - 1 << 20; ///< @private bit 20 of ECX for EAX=0x1 -constexpr uint32_t cpuid_pclmulqdq_bit = - 1 << 1; ///< @private bit 1 of ECX for EAX=0x1 -} - - -static inline void cpuid(uint32_t *eax, - uint32_t *ebx, - uint32_t *ecx, - uint32_t *edx) { -#if defined(_MSC_VER) - int cpu_info[4]; - __cpuid(cpu_info, *eax); - *eax = cpu_info[0]; - *ebx = cpu_info[1]; - *ecx = cpu_info[2]; - *edx = cpu_info[3]; -#elif defined(HAVE_GCC_GET_CPUID) && defined(USE_GCC_GET_CPUID) - uint32_t level = *eax; - __get_cpuid(level, eax, ebx, ecx, edx); -#else - uint32_t a = *eax, b, c = *ecx, d; - asm volatile("cpuid\n\t" : "+a"(a), "=b"(b), "+c"(c), "=d"(d)); - *eax = a; - *ebx = b; - *ecx = c; - *edx = d; -#endif -} - -static inline uint32_t detect_supported_architectures() { - uint32_t eax, ebx, ecx, edx; - uint32_t host_isa = 0x0; - - // ECX for EAX=0x7 - eax = 0x7; - ecx = 0x0; - cpuid(&eax, &ebx, &ecx, &edx); - if (ebx & cpuid_avx2_bit) { - host_isa |= instruction_set::AVX2; - } - if (ebx & cpuid_bmi1_bit) { - host_isa |= instruction_set::BMI1; - } - - if (ebx & cpuid_bmi2_bit) { - host_isa |= instruction_set::BMI2; - } - - // EBX for EAX=0x1 - eax = 0x1; - cpuid(&eax, &ebx, &ecx, &edx); - - if (ecx & cpuid_sse42_bit) { - host_isa |= instruction_set::SSE42; - } - - if (ecx & cpuid_pclmulqdq_bit) { 
- host_isa |= instruction_set::PCLMULQDQ; - } - - return host_isa; -} -#else // fallback - - -static inline uint32_t detect_supported_architectures() { - return instruction_set::DEFAULT; -} - - -#endif // end SIMD extension detection code - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_ISADETECTION_H -/* end file include/simdjson/internal/isadetection.h */ -#include -#include -#include - -namespace simdjson { - -/** - * Validate the UTF-8 string. - * - * @param buf the string to validate. - * @param len the length of the string in bytes. - * @return true if the string is valid UTF-8. - */ -simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) noexcept; - - -/** - * Validate the UTF-8 string. - * - * @param sv the string_view to validate. - * @return true if the string is valid UTF-8. - */ -simdjson_really_inline simdjson_warn_unused bool validate_utf8( - const std::string_view sv) noexcept { - return validate_utf8(sv.data(), sv.size()); -} - -/** - * Validate the UTF-8 string. - * - * @param p the string to validate. - * @return true if the string is valid UTF-8. - */ -simdjson_really_inline simdjson_warn_unused bool validate_utf8( - const std::string &s) noexcept { - return validate_utf8(s.data(), s.size()); -} - -namespace dom { -class document; -} // namespace dom - -/** - * An implementation of simdjson for a particular CPU architecture. - * - * Also used to maintain the currently active implementation. The active - * implementation is - * automatically initialized on first use to the most advanced implementation - * supported by the host. - */ -class implementation { - public: - /** - * The name of this implementation. - * - * const implementation *impl = simdjson::get_active_implementation(); - * cout << "simdjson is optimized for " << impl->name() << "(" << - * impl->description() << ")" << endl; - * - * @return the name of the implementation, e.g. 
"haswell", "westmere", - * "arm64" - */ - virtual const std::string &name() const { return _name; } - - /** - * The description of this implementation. - * - * const implementation *impl = simdjson::get_active_implementation(); - * cout << "simdjson is optimized for " << impl->name() << "(" << - * impl->description() << ")" << endl; - * - * @return the name of the implementation, e.g. "haswell", "westmere", - * "arm64" - */ - virtual const std::string &description() const { return _description; } - - /** - * The instruction sets this implementation is compiled against - * and the current CPU match. This function may poll the current CPU/system - * and should therefore not be called too often if performance is a concern. - * - * - * @return true if the implementation can be safely used on the current - * system (determined at runtime) - */ - bool supported_by_runtime_system() const; - - /** - * @private For internal implementation use - * - * The instruction sets this implementation is compiled against. - * - * @return a mask of all required `internal::instruction_set::` values - */ - virtual uint32_t required_instruction_sets() const { - return _required_instruction_sets; - }; - - /** - * @private For internal implementation use - * - * const implementation *impl = simdjson::get_active_implementation(); - * cout << "simdjson is optimized for " << impl->name() << "(" << - * impl->description() << ")" << endl; - * - * @param capacity The largest document that will be passed to the parser. - * @param max_depth The maximum JSON object/array nesting this parser is - * expected to handle. - * @param dst The place to put the resulting parser implementation. - * @return the name of the implementation, e.g. 
"haswell", "westmere", - * "arm64" - */ - virtual error_code create_dom_parser_implementation( - size_t capacity, - size_t max_depth, - std::unique_ptr &dst) const - noexcept = 0; - - /** - * @private For internal implementation use - * - * Minify the input string assuming that it represents a JSON string, does - * not parse or validate. - * - * Overridden by each implementation. - * - * @param buf the json document to minify. - * @param len the length of the json document. - * @param dst the buffer to write the minified document to. *MUST* be - * allocated up to len + SIMDJSON_PADDING bytes. - * @param dst_len the number of bytes written. Output only. - * @return the error code, or SUCCESS if there was no error. - */ - simdjson_warn_unused virtual error_code minify(const uint8_t *buf, - size_t len, - uint8_t *dst, - size_t &dst_len) const - noexcept = 0; - - - /** - * Validate the UTF-8 string. - * - * Overridden by each implementation. - * - * @param buf the string to validate. - * @param len the length of the string in bytes. - * @return true if and only if the string is valid UTF-8. - */ - simdjson_warn_unused virtual bool validate_utf8(const char *buf, - size_t len) const - noexcept = 0; - - protected: - /** @private Construct an implementation with the given name and - * description. For subclasses. */ - simdjson_really_inline implementation(std::string_view name, - std::string_view description, - uint32_t required_instruction_sets) - : _name(name), - _description(description), - _required_instruction_sets(required_instruction_sets) {} - virtual ~implementation() = default; - - private: - /** - * The name of this implementation. - */ - const std::string _name; - - /** - * The description of this implementation. - */ - const std::string _description; - - /** - * Instruction sets required for this implementation. 
- */ - const uint32_t _required_instruction_sets; -}; - -/** @private */ -namespace internal { - -/** - * The list of available implementations compiled into simdjson. - */ -class available_implementation_list { - public: - /** Get the list of available implementations compiled into simdjson */ - simdjson_really_inline available_implementation_list() {} - /** Number of implementations */ - size_t size() const noexcept; - /** STL const begin() iterator */ - const implementation *const *begin() const noexcept; - /** STL const end() iterator */ - const implementation *const *end() const noexcept; - - /** - * Get the implementation with the given name. - * - * Case sensitive. - * - * const implementation *impl = - * simdjson::get_available_implementations()["westmere"]; - * if (!impl) { exit(1); } - * if (!imp->supported_by_runtime_system()) { exit(1); } - * simdjson::get_active_implementation() = impl; - * - * @param name the implementation to find, e.g. "westmere", "haswell", - * "arm64" - * @return the implementation, or nullptr if the parse failed. - */ - const implementation *operator[](const std::string_view &name) const - noexcept { - for (const implementation *impl : *this) { - if (impl->name() == name) { - return impl; - } - } - return nullptr; - } - - /** - * Detect the most advanced implementation supported by the current host. - * - * This is used to initialize the implementation on startup. - * - * const implementation *impl = - * simdjson::available_implementation::detect_best_supported(); - * simdjson::get_active_implementation() = impl; - * - * @return the most advanced supported implementation for the current host, - * or an - * implementation that returns UNSUPPORTED_ARCHITECTURE if there is - * no supported - * implementation. Will never return nullptr. 
- */ - const implementation *detect_best_supported() const noexcept; -}; - -template -class atomic_ptr { - public: - atomic_ptr(T *_ptr) : ptr{_ptr} {} - - operator const T *() const { return ptr.load(); } - const T &operator*() const { return *ptr; } - const T *operator->() const { return ptr.load(); } - - operator T *() { return ptr.load(); } - T &operator*() { return *ptr; } - T *operator->() { return ptr.load(); } - atomic_ptr &operator=(T *_ptr) { - ptr = _ptr; - return *this; - } - - private: - std::atomic ptr; -}; - -} // namespace internal - -/** - * The list of available implementations compiled into simdjson. - */ -extern SIMDJSON_DLLIMPORTEXPORT const internal::available_implementation_list & -get_available_implementations(); - -/** - * The active implementation. - * - * Automatically initialized on first use to the most advanced implementation - * supported by this hardware. - */ -extern SIMDJSON_DLLIMPORTEXPORT internal::atomic_ptr - &get_active_implementation(); - -} // namespace simdjson - -#endif // SIMDJSON_IMPLEMENTATION_H -/* end file include/simdjson/implementation.h */ - -// Inline functions -/* begin file include/simdjson/error-inl.h */ -#ifndef SIMDJSON_INLINE_ERROR_H -#define SIMDJSON_INLINE_ERROR_H - -#include -#include -#include - -namespace simdjson { -namespace internal { -// We store the error code so we can validate the error message is associated -// with the right code -struct error_code_info { - error_code code; - const char *message; // do not use a fancy std::string where a simple C - // string will do (no alloc, no destructor) -}; -// These MUST match the codes in error_code. We check this constraint in -// basictests. -extern SIMDJSON_DLLIMPORTEXPORT const error_code_info error_codes[]; -} // namespace internal - - -inline const char *error_message(error_code error) noexcept { - // If you're using error_code, we're trusting you got it from the enum. 
- return internal::error_codes[int(error)].message; -} - -// deprecated function -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -inline const std::string error_message(int error) noexcept { - if (error < 0 || error >= error_code::NUM_ERROR_CODES) { - return internal::error_codes[UNEXPECTED_ERROR].message; - } - return internal::error_codes[error].message; -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API - -inline std::ostream &operator<<(std::ostream &out, error_code error) noexcept { - return out << error_message(error); -} - -namespace internal { - -// -// internal::simdjson_result_base inline implementation -// - -template - simdjson_really_inline void simdjson_result_base::tie( - T &value, error_code &error) && - noexcept { - error = this->second; - if (!error) { - value = std::forward>(*this).first; - } -} - -template - simdjson_warn_unused simdjson_really_inline error_code - simdjson_result_base::get(T &value) && - noexcept { - error_code error; - std::forward>(*this).tie(value, error); - return error; -} - -template -simdjson_really_inline error_code simdjson_result_base::error() const - noexcept { - return this->second; -} - -#if SIMDJSON_EXCEPTIONS - -template - simdjson_really_inline T &simdjson_result_base::value() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return this->first; -} - -template - simdjson_really_inline T &&simdjson_result_base::value() && - noexcept(false) { - return std::forward>(*this).take_value(); -} - -template - simdjson_really_inline T &&simdjson_result_base::take_value() && - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return std::forward(this->first); -} - -template - simdjson_really_inline simdjson_result_base::operator T &&() && - noexcept(false) { - return std::forward>(*this).take_value(); -} - -#endif // SIMDJSON_EXCEPTIONS - -template -simdjson_really_inline const T &simdjson_result_base::value_unsafe() - const &noexcept { - return this->first; -} - -template - 
simdjson_really_inline T &&simdjson_result_base::value_unsafe() && - noexcept { - return std::forward(this->first); -} - -template -simdjson_really_inline simdjson_result_base::simdjson_result_base( - T &&value, error_code error) noexcept - : std::pair(std::forward(value), error) {} -template -simdjson_really_inline simdjson_result_base::simdjson_result_base( - error_code error) noexcept : simdjson_result_base(T{}, error) {} -template -simdjson_really_inline simdjson_result_base::simdjson_result_base( - T &&value) noexcept - : simdjson_result_base(std::forward(value), SUCCESS) {} -template -simdjson_really_inline simdjson_result_base::simdjson_result_base() noexcept - : simdjson_result_base(T{}, UNINITIALIZED) {} - -} // namespace internal - -/// -/// simdjson_result inline implementation -/// - -template - simdjson_really_inline void simdjson_result::tie(T &value, - error_code &error) && - noexcept { - std::forward>(*this).tie(value, error); -} - -template - simdjson_warn_unused simdjson_really_inline error_code - simdjson_result::get(T &value) && - noexcept { - return std::forward>(*this).get(value); -} - -template -simdjson_really_inline error_code simdjson_result::error() const noexcept { - return internal::simdjson_result_base::error(); -} - -#if SIMDJSON_EXCEPTIONS - -template - simdjson_really_inline T &simdjson_result::value() & noexcept(false) { - return internal::simdjson_result_base::value(); -} - -template - simdjson_really_inline T &&simdjson_result::value() && noexcept(false) { - return std::forward>(*this).value(); -} - -template - simdjson_really_inline T &&simdjson_result::take_value() && - noexcept(false) { - return std::forward>(*this).take_value(); -} - -template - simdjson_really_inline simdjson_result::operator T &&() && - noexcept(false) { - return std::forward>(*this).take_value(); -} - -#endif // SIMDJSON_EXCEPTIONS - -template -simdjson_really_inline const T &simdjson_result::value_unsafe() - const &noexcept { - return 
internal::simdjson_result_base::value_unsafe(); -} - -template - simdjson_really_inline T &&simdjson_result::value_unsafe() && noexcept { - return std::forward>(*this) - .value_unsafe(); -} - -template -simdjson_really_inline simdjson_result::simdjson_result( - T &&value, error_code error) noexcept - : internal::simdjson_result_base(std::forward(value), error) {} -template -simdjson_really_inline simdjson_result::simdjson_result( - error_code error) noexcept : internal::simdjson_result_base(error) {} -template -simdjson_really_inline simdjson_result::simdjson_result(T &&value) noexcept - : internal::simdjson_result_base(std::forward(value)) {} -template -simdjson_really_inline simdjson_result::simdjson_result() noexcept - : internal::simdjson_result_base() {} - -} // namespace simdjson - -#endif // SIMDJSON_INLINE_ERROR_H -/* end file include/simdjson/error-inl.h */ -/* begin file include/simdjson/padded_string-inl.h */ -#ifndef SIMDJSON_INLINE_PADDED_STRING_H -#define SIMDJSON_INLINE_PADDED_STRING_H - - -#include -#include -#include -#include - -namespace simdjson { -namespace internal { - -// The allocate_padded_buffer function is a low-level function to allocate -// memory -// with padding so we can read past the "length" bytes safely. It is used by -// the padded_string class automatically. It returns nullptr in case -// of error: the caller should check for a null pointer. -// The length parameter is the maximum size in bytes of the string. -// The caller is responsible to free the memory (e.g., delete[] (...)). 
-inline char *allocate_padded_buffer(size_t length) noexcept { - const size_t totalpaddedlength = length + SIMDJSON_PADDING; - if (totalpaddedlength < length) { - // overflow - return nullptr; - } -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION - // avoid getting out of memory - if (totalpaddedlength > (1UL << 20)) { - return nullptr; - } -#endif - - char *padded_buffer = new (std::nothrow) char[totalpaddedlength]; - if (padded_buffer == nullptr) { - return nullptr; - } - // We write zeroes in the padded region to avoid having uninitized - // garbage. If nothing else, garbage getting read might trigger a - // warning in a memory checking. - std::memset(padded_buffer + length, 0, totalpaddedlength - length); - return padded_buffer; -} // allocate_padded_buffer() - -} // namespace internal - - -inline padded_string::padded_string() noexcept {} -inline padded_string::padded_string(size_t length) noexcept - : viable_size(length), - data_ptr(internal::allocate_padded_buffer(length)) {} -inline padded_string::padded_string(const char *data, size_t length) noexcept - : viable_size(length), - data_ptr(internal::allocate_padded_buffer(length)) { - if ((data != nullptr) && (data_ptr != nullptr)) { - std::memcpy(data_ptr, data, length); - } -} -// note: do not pass std::string arguments by value -inline padded_string::padded_string(const std::string &str_) noexcept - : viable_size(str_.size()), - data_ptr(internal::allocate_padded_buffer(str_.size())) { - if (data_ptr != nullptr) { - std::memcpy(data_ptr, str_.data(), str_.size()); - } -} -// note: do pass std::string_view arguments by value -inline padded_string::padded_string(std::string_view sv_) noexcept - : viable_size(sv_.size()), - data_ptr(internal::allocate_padded_buffer(sv_.size())) { - if (simdjson_unlikely(!data_ptr)) { - // allocation failed or zero size - viable_size = 0; - return; - } - if (sv_.size()) { - std::memcpy(data_ptr, sv_.data(), sv_.size()); - } -} -inline padded_string::padded_string(padded_string 
&&o) noexcept - : viable_size(o.viable_size), - data_ptr(o.data_ptr) { - o.data_ptr = nullptr; // we take ownership -} - -inline padded_string &padded_string::operator=(padded_string &&o) noexcept { - delete[] data_ptr; - data_ptr = o.data_ptr; - viable_size = o.viable_size; - o.data_ptr = nullptr; // we take ownership - o.viable_size = 0; - return *this; -} - -inline void padded_string::swap(padded_string &o) noexcept { - size_t tmp_viable_size = viable_size; - char *tmp_data_ptr = data_ptr; - viable_size = o.viable_size; - data_ptr = o.data_ptr; - o.data_ptr = tmp_data_ptr; - o.viable_size = tmp_viable_size; -} - -inline padded_string::~padded_string() noexcept { delete[] data_ptr; } - -inline size_t padded_string::size() const noexcept { return viable_size; } - -inline size_t padded_string::length() const noexcept { return viable_size; } - -inline const char *padded_string::data() const noexcept { return data_ptr; } - -inline char *padded_string::data() noexcept { return data_ptr; } - -inline padded_string::operator std::string_view() const { - return std::string_view(data(), length()); -} - -inline padded_string::operator padded_string_view() const noexcept { - return padded_string_view(data(), length(), length() + SIMDJSON_PADDING); -} - -inline simdjson_result padded_string::load( - std::string_view filename) noexcept { - // Open the file - SIMDJSON_PUSH_DISABLE_WARNINGS - SIMDJSON_DISABLE_DEPRECATED_WARNING // Disable CRT_SECURE warning on MSVC: - // manually verified this is safe - std::FILE *fp = std::fopen(filename.data(), "rb"); - SIMDJSON_POP_DISABLE_WARNINGS - - if (fp == nullptr) { - return IO_ERROR; - } - - // Get the file size - if (std::fseek(fp, 0, SEEK_END) < 0) { - std::fclose(fp); - return IO_ERROR; - } -#if defined(SIMDJSON_VISUAL_STUDIO) && !SIMDJSON_IS_32BITS - __int64 llen = _ftelli64(fp); - if (llen == -1L) { - std::fclose(fp); - return IO_ERROR; - } -#else - long llen = std::ftell(fp); - if ((llen < 0) || (llen == LONG_MAX)) { - 
std::fclose(fp); - return IO_ERROR; - } -#endif - - // Allocate the padded_string - size_t len = static_cast(llen); - padded_string s(len); - if (s.data() == nullptr) { - std::fclose(fp); - return MEMALLOC; - } - - // Read the padded_string - std::rewind(fp); - size_t bytes_read = std::fread(s.data(), 1, len, fp); - if (std::fclose(fp) != 0 || bytes_read != len) { - return IO_ERROR; - } - - return s; -} - -} // namespace simdjson - -#endif // SIMDJSON_INLINE_PADDED_STRING_H -/* end file include/simdjson/padded_string-inl.h */ -/* begin file include/simdjson/padded_string_view-inl.h */ -#ifndef SIMDJSON_PADDED_STRING_VIEW_INL_H -#define SIMDJSON_PADDED_STRING_VIEW_INL_H - - -#include -#include -#include -#include - -namespace simdjson { - -inline padded_string_view::padded_string_view(const char *s, - size_t len, - size_t capacity) noexcept - : std::string_view(s, len), - _capacity(capacity) {} - -inline padded_string_view::padded_string_view(const uint8_t *s, - size_t len, - size_t capacity) noexcept - : padded_string_view(reinterpret_cast(s), len, capacity) {} - -inline padded_string_view::padded_string_view(const std::string &s) noexcept - : std::string_view(s), - _capacity(s.capacity()) {} - -inline padded_string_view::padded_string_view(std::string_view s, - size_t capacity) noexcept - : std::string_view(s), - _capacity(capacity) {} - -inline size_t padded_string_view::capacity() const noexcept { - return _capacity; -} - -inline size_t padded_string_view::padding() const noexcept { - return capacity() - length(); -} - -} // namespace simdjson - -#endif // SIMDJSON_PADDED_STRING_VIEW_INL_H -/* end file include/simdjson/padded_string_view-inl.h */ - -SIMDJSON_POP_DISABLE_WARNINGS - -#endif // SIMDJSON_BASE_H -/* end file include/simdjson/base.h */ - -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_UNDESIRED_WARNINGS - -/* begin file include/simdjson/dom/array.h */ -#ifndef SIMDJSON_DOM_ARRAY_H -#define SIMDJSON_DOM_ARRAY_H - -/* begin file 
include/simdjson/internal/tape_ref.h */ -#ifndef SIMDJSON_INTERNAL_TAPE_REF_H -#define SIMDJSON_INTERNAL_TAPE_REF_H - -/* begin file include/simdjson/internal/tape_type.h */ -#ifndef SIMDJSON_INTERNAL_TAPE_TYPE_H -#define SIMDJSON_INTERNAL_TAPE_TYPE_H - -namespace simdjson { -namespace internal { - -/** - * The possible types in the tape. - */ -enum class tape_type { - ROOT = 'r', - START_ARRAY = '[', - START_OBJECT = '{', - END_ARRAY = ']', - END_OBJECT = '}', - STRING = '"', - INT64 = 'l', - UINT64 = 'u', - DOUBLE = 'd', - TRUE_VALUE = 't', - FALSE_VALUE = 'f', - NULL_VALUE = 'n' -}; // enum class tape_type - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_TAPE_TYPE_H -/* end file include/simdjson/internal/tape_type.h */ - -namespace simdjson { - -namespace dom { -class document; -} - -namespace internal { - -constexpr const uint64_t JSON_VALUE_MASK = 0x00FFFFFFFFFFFFFF; -constexpr const uint32_t JSON_COUNT_MASK = 0xFFFFFF; - -/** - * A reference to an element on the tape. Internal only. - */ -class tape_ref { - public: - simdjson_really_inline tape_ref() noexcept; - simdjson_really_inline tape_ref(const dom::document *doc, - size_t json_index) noexcept; - inline size_t after_element() const noexcept; - simdjson_really_inline tape_type tape_ref_type() const noexcept; - simdjson_really_inline uint64_t tape_value() const noexcept; - simdjson_really_inline bool is_double() const noexcept; - simdjson_really_inline bool is_int64() const noexcept; - simdjson_really_inline bool is_uint64() const noexcept; - simdjson_really_inline bool is_false() const noexcept; - simdjson_really_inline bool is_true() const noexcept; - simdjson_really_inline bool is_null_on_tape() const - noexcept; // different name to avoid clash with is_null. 
- simdjson_really_inline uint32_t matching_brace_index() const noexcept; - simdjson_really_inline uint32_t scope_count() const noexcept; - template - simdjson_really_inline T next_tape_value() const noexcept; - simdjson_really_inline uint32_t get_string_length() const noexcept; - simdjson_really_inline const char *get_c_str() const noexcept; - inline std::string_view get_string_view() const noexcept; - simdjson_really_inline bool is_document_root() const noexcept; - - /** The document this element references. */ - const dom::document *doc; - - /** The index of this element on `doc.tape[]` */ - size_t json_index; -}; - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_TAPE_REF_H -/* end file include/simdjson/internal/tape_ref.h */ - -namespace simdjson { - -namespace internal { -template -class string_builder; -} -namespace dom { - -class document; -class element; - -/** - * JSON array. - */ -class array { - public: - /** Create a new, invalid array */ - simdjson_really_inline array() noexcept; - - class iterator { - public: - using value_type = element; - using difference_type = std::ptrdiff_t; - - /** - * Get the actual value - */ - inline value_type operator*() const noexcept; - /** - * Get the next value. - * - * Part of the std::iterator interface. - */ - inline iterator &operator++() noexcept; - /** - * Get the next value. - * - * Part of the std::iterator interface. - */ - inline iterator operator++(int)noexcept; - /** - * Check if these values come from the same place in the JSON. - * - * Part of the std::iterator interface. 
- */ - inline bool operator!=(const iterator &other) const noexcept; - inline bool operator==(const iterator &other) const noexcept; - - inline bool operator<(const iterator &other) const noexcept; - inline bool operator<=(const iterator &other) const noexcept; - inline bool operator>=(const iterator &other) const noexcept; - inline bool operator>(const iterator &other) const noexcept; - - iterator() noexcept = default; - iterator(const iterator &) noexcept = default; - iterator &operator=(const iterator &) noexcept = default; - - private: - simdjson_really_inline iterator( - const internal::tape_ref &tape) noexcept; - internal::tape_ref tape; - friend class array; - }; - - /** - * Return the first array element. - * - * Part of the std::iterable interface. - */ - inline iterator begin() const noexcept; - /** - * One past the last array element. - * - * Part of the std::iterable interface. - */ - inline iterator end() const noexcept; - /** - * Get the size of the array (number of immediate children). - * It is a saturated value with a maximum of 0xFFFFFF: if the value - * is 0xFFFFFF then the size is 0xFFFFFF or greater. - */ - inline size_t size() const noexcept; - /** - * Get the total number of slots used by this array on the tape. - * - * Note that this is not the same thing as `size()`, which reports the - * number of actual elements within an array (not counting its children). - * - * Since an element can use 1 or 2 slots on the tape, you can only use this - * to figure out the total size of an array (including its children, - * recursively) if you know its structure ahead of time. - **/ - inline size_t number_of_slots() const noexcept; - /** - * Get the value associated with the given JSON pointer. We use the RFC - * 6901 - * https://tools.ietf.org/html/rfc6901 standard, interpreting the current - * node - * as the root of its own JSON document. 
- * - * dom::parser parser; - * array a = parser.parse(R"([ { "foo": { "a": [ 10, 20, 30 ] }} - * ])"_padded); - * a.at_pointer("/0/foo/a/1") == 20 - * a.at_pointer("0")["foo"]["a"].at(1) == 20 - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - inline simdjson_result at_pointer( - std::string_view json_pointer) const noexcept; - - /** - * Get the value at the given index. This function has linear-time - * complexity and - * is equivalent to the following: - * - * size_t i=0; - * for (auto element : *this) { - * if (i == index) { return element; } - * i++; - * } - * return INDEX_OUT_OF_BOUNDS; - * - * Avoid calling the at() function repeatedly. - * - * @return The value at the given index, or: - * - INDEX_OUT_OF_BOUNDS if the array index is larger than an array - * length - */ - inline simdjson_result at(size_t index) const noexcept; - - private: - simdjson_really_inline array(const internal::tape_ref &tape) noexcept; - internal::tape_ref tape; - friend class element; - friend struct simdjson_result; - template - friend class simdjson::internal::string_builder; -}; - - -} // namespace dom - -/** The result of a JSON conversion that may fail. 
*/ -template <> -struct simdjson_result - : public internal::simdjson_result_base { - public: - simdjson_really_inline simdjson_result() noexcept; ///< @private - simdjson_really_inline simdjson_result( - dom::array value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - - inline simdjson_result at_pointer( - std::string_view json_pointer) const noexcept; - inline simdjson_result at(size_t index) const noexcept; - -#if SIMDJSON_EXCEPTIONS - inline dom::array::iterator begin() const noexcept(false); - inline dom::array::iterator end() const noexcept(false); - inline size_t size() const noexcept(false); -#endif // SIMDJSON_EXCEPTIONS -}; - - -} // namespace simdjson - -#if defined(__cpp_lib_ranges) -#include - -namespace std { -namespace ranges { -template <> -inline constexpr bool enable_view = true; -#if SIMDJSON_EXCEPTIONS -template <> -inline constexpr bool - enable_view> = true; -#endif // SIMDJSON_EXCEPTIONS -} // namespace ranges -} // namespace std -#endif // defined(__cpp_lib_ranges) - -#endif // SIMDJSON_DOM_ARRAY_H -/* end file include/simdjson/dom/array.h */ -/* begin file include/simdjson/dom/document_stream.h */ -#ifndef SIMDJSON_DOCUMENT_STREAM_H -#define SIMDJSON_DOCUMENT_STREAM_H - -/* begin file include/simdjson/dom/parser.h */ -#ifndef SIMDJSON_DOM_PARSER_H -#define SIMDJSON_DOM_PARSER_H - -/* begin file include/simdjson/dom/document.h */ -#ifndef SIMDJSON_DOM_DOCUMENT_H -#define SIMDJSON_DOM_DOCUMENT_H - -#include -#include - -namespace simdjson { -namespace dom { - -class element; - -/** - * A parsed JSON document. - * - * This class cannot be copied, only moved, to avoid unintended allocations. - */ -class document { - public: - /** - * Create a document container with zero capacity. - * - * The parser will allocate capacity as needed. - */ - document() noexcept = default; - ~document() noexcept = default; - - /** - * Take another document's buffers. 
- * - * @param other The document to take. Its capacity is zeroed and it is - * invalidated. - */ - document(document &&other) noexcept = default; - /** @private */ - document(const document &) = delete; // Disallow copying - /** - * Take another document's buffers. - * - * @param other The document to take. Its capacity is zeroed. - */ - document &operator=(document &&other) noexcept = default; - /** @private */ - document &operator=(const document &) = delete; // Disallow copying - - /** - * Get the root element of this document as a JSON array. - */ - element root() const noexcept; - - /** - * @private Dump the raw tape for debugging. - * - * @param os the stream to output to. - * @return false if the tape is likely wrong (e.g., you did not parse a - * valid JSON). - */ - bool dump_raw_tape(std::ostream &os) const noexcept; - - /** @private Structural values. */ - std::unique_ptr tape{}; - - /** @private String values. - * - * Should be at least byte_capacity. - */ - std::unique_ptr string_buf{}; - /** @private Allocate memory to support - * input JSON documents of up to len bytes. - * - * When calling this function, you lose - * all the data. - * - * The memory allocation is strict: you - * can you use this function to increase - * or lower the amount of allocated memory. - * Passsing zero clears the memory. - */ - error_code allocate(size_t len) noexcept; - /** @private Capacity in bytes, in terms - * of how many bytes of input JSON we can - * support. 
- */ - size_t capacity() const noexcept; - - - private: - size_t allocated_capacity{0}; - friend class parser; -}; // class document - -} // namespace dom -} // namespace simdjson - -#endif // SIMDJSON_DOM_DOCUMENT_H -/* end file include/simdjson/dom/document.h */ -#include -#include -#include - -namespace simdjson { - -namespace dom { - -class document_stream; -class element; - -/** The default batch size for parser.parse_many() and parser.load_many() */ -static constexpr size_t DEFAULT_BATCH_SIZE = 1000000; -/** - * Some adversary might try to set the batch size to 0 or 1, which might cause - * problems. - * We set a minimum of 32B since anything else is highly likely to be an error. - * In practice, - * most users will want a much larger batch size. - * - * All non-negative MINIMAL_BATCH_SIZE values should be 'safe' except that, - * obviously, no JSON - * document can ever span 0 or 1 byte and that very large values would create - * memory allocation issues. - */ -static constexpr size_t MINIMAL_BATCH_SIZE = 32; - -/** - * It is wasteful to allocate memory for tiny documents (e.g., 4 bytes). - */ -static constexpr size_t MINIMAL_DOCUMENT_CAPACITY = 32; - -/** - * A persistent document parser. - * - * The parser is designed to be reused, holding the internal buffers necessary - * to do parsing, - * as well as memory for a single document. The parsed document is overwritten - * on each parse. - * - * This class cannot be copied, only moved, to avoid unintended allocations. - * - * @note Moving a parser instance may invalidate "dom::element" instances. If - * you need to - * preserve both the "dom::element" instances and the parser, consider wrapping - * the parser - * instance in a std::unique_ptr instance: - * - * std::unique_ptr parser(new dom::parser{}); - * auto error = parser->load(f).get(root); - * - * You can then move std::unique_ptr safely. - * - * @note This is not thread safe: one parser cannot produce two documents at the - * same time! 
- */ -class parser { - public: - /** - * Create a JSON parser. - * - * The new parser will have zero capacity. - * - * @param max_capacity The maximum document length the parser can - * automatically handle. The parser - * will allocate more capacity on an as needed basis (when it sees - * documents too big to handle) - * up to this amount. The parser still starts with zero capacity no - * matter what this number is: - * to allocate an initial capacity, call allocate() after constructing - * the parser. - * Defaults to SIMDJSON_MAXSIZE_BYTES (the largest single document - * simdjson can process). - */ - simdjson_really_inline explicit parser( - size_t max_capacity = SIMDJSON_MAXSIZE_BYTES) noexcept; - /** - * Take another parser's buffers and state. - * - * @param other The parser to take. Its capacity is zeroed. - */ - simdjson_really_inline parser(parser &&other) noexcept; - parser(const parser &) = delete; ///< @private Disallow copying - /** - * Take another parser's buffers and state. - * - * @param other The parser to take. Its capacity is zeroed. - */ - simdjson_really_inline parser &operator=(parser &&other) noexcept; - parser &operator=(const parser &) = delete; ///< @private Disallow copying - - /** Deallocate the JSON parser. */ - ~parser() = default; - - /** - * Load a JSON document from a file and return a reference to it. - * - * dom::parser parser; - * const element doc = parser.load("jsonexamples/twitter.json"); - * - * The function is eager: the file's content is loaded in memory inside the - * parser instance - * and immediately parsed. The file can be deleted after the `parser.load` - * call. - * - * ### IMPORTANT: Document Lifetime - * - * The JSON document still lives in the parser: this is the most efficient - * way to parse JSON - * documents because it reuses the same buffers, but you *must* use the - * document before you - * destroy the parser or call parse() again. 
- * - * Moving the parser instance is safe, but it invalidates the element - * instances. You may store - * the parser instance without moving it by wrapping it inside an - * `unique_ptr` instance like - * so: `std::unique_ptr parser(new dom::parser{});`. - * - * ### Parser Capacity - * - * If the parser's current capacity is less than the file length, it will - * allocate enough capacity - * to handle it (up to max_capacity). - * - * @param path The path to load. - * @return The document, or an error: - * - IO_ERROR if there was an error opening or reading the file. - * Be mindful that on some 32-bit systems, - * the file size might be limited to 2 GB. - * - MEMALLOC if the parser does not have enough capacity and memory - * allocation fails. - * - CAPACITY if the parser does not have enough capacity and len > - * max_capacity. - * - other json errors if parsing fails. You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result load(const std::string &path) & noexcept; - inline simdjson_result load(const std::string &path) && = delete; - /** - * Parse a JSON document and return a temporary reference to it. - * - * dom::parser parser; - * element doc_root = parser.parse(buf, len); - * - * The function eagerly parses the input: the input can be modified and - * discarded after - * the `parser.parse(buf, len)` call has completed. - * - * ### IMPORTANT: Document Lifetime - * - * The JSON document still lives in the parser: this is the most efficient - * way to parse JSON - * documents because it reuses the same buffers, but you *must* use the - * document before you - * destroy the parser or call parse() again. - * - * Moving the parser instance is safe, but it invalidates the element - * instances. 
You may store - * the parser instance without moving it by wrapping it inside an - * `unique_ptr` instance like - * so: `std::unique_ptr parser(new dom::parser{});`. - * - * ### REQUIRED: Buffer Padding - * - * The buffer must have at least SIMDJSON_PADDING extra allocated bytes. It - * does not matter what - * those bytes are initialized to, as long as they are allocated. - * - * If realloc_if_needed is true (the default), it is assumed that the buffer - * does *not* have enough padding, - * and it is copied into an enlarged temporary buffer before parsing. Thus - * the following is safe: - * - * const char *json = R"({"key":"value"})"; - * const size_t json_len = std::strlen(json); - * simdjson::dom::parser parser; - * simdjson::dom::element element = parser.parse(json, json_len); - * - * If you set realloc_if_needed to false (e.g., parser.parse(json, json_len, - * false)), - * you must provide a buffer with at least SIMDJSON_PADDING extra bytes at - * the end. - * The benefit of setting realloc_if_needed to false is that you avoid a - * temporary - * memory allocation and a copy. - * - * The padded bytes may be read. It is not important how you initialize - * these bytes though we recommend a sensible default like null character - * values or spaces. - * For example, the following low-level code is safe: - * - * const char *json = R"({"key":"value"})"; - * const size_t json_len = std::strlen(json); - * std::unique_ptr padded_json_copy{new char[json_len + - * SIMDJSON_PADDING]}; - * std::memcpy(padded_json_copy.get(), json, json_len); - * std::memset(padded_json_copy.get() + json_len, '\0', SIMDJSON_PADDING); - * simdjson::dom::parser parser; - * simdjson::dom::element element = parser.parse(padded_json_copy.get(), - * json_len, false); - * - * ### Parser Capacity - * - * If the parser's current capacity is less than len, it will allocate - * enough capacity - * to handle it (up to max_capacity). - * - * @param buf The JSON to parse. 
Must have at least len + SIMDJSON_PADDING - * allocated bytes, unless - * realloc_if_needed is true. - * @param len The length of the JSON. - * @param realloc_if_needed Whether to reallocate and enlarge the JSON - * buffer to add padding. - * @return An element pointing at the root of the document, or an error: - * - MEMALLOC if realloc_if_needed is true or the parser does not - * have enough capacity, - * and memory allocation fails. - * - CAPACITY if the parser does not have enough capacity and len > - * max_capacity. - * - other json errors if parsing fails. You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result parse(const uint8_t *buf, - size_t len, - bool realloc_if_needed = true) & - noexcept; - inline simdjson_result parse(const uint8_t *buf, - size_t len, - bool realloc_if_needed = true) && = - delete; - /** @overload parse(const uint8_t *buf, size_t len, bool realloc_if_needed) - */ - simdjson_really_inline simdjson_result parse( - const char *buf, size_t len, bool realloc_if_needed = true) & - noexcept; - simdjson_really_inline simdjson_result parse( - const char *buf, size_t len, bool realloc_if_needed = true) && = delete; - /** @overload parse(const uint8_t *buf, size_t len, bool realloc_if_needed) - */ - simdjson_really_inline simdjson_result parse( - const std::string &s) & - noexcept; - simdjson_really_inline simdjson_result parse( - const std::string &s) && = delete; - /** @overload parse(const uint8_t *buf, size_t len, bool realloc_if_needed) - */ - simdjson_really_inline simdjson_result parse( - const padded_string &s) & - noexcept; - simdjson_really_inline simdjson_result parse( - const padded_string &s) && = delete; - - /** @private We do not want to allow implicit conversion from C string to - * std::string. 
*/ - simdjson_really_inline simdjson_result parse( - const char *buf) noexcept = delete; - - /** - * Parse a JSON document into a provide document instance and return a - * temporary reference to it. - * It is similar to the function `parse` except that instead of parsing into - * the internal - * `document` instance associated with the parser, it allows the user to - * provide a document - * instance. - * - * dom::parser parser; - * dom::document doc; - * element doc_root = parser.parse_into_document(doc, buf, len); - * - * The function eagerly parses the input: the input can be modified and - * discarded after - * the `parser.parse(buf, len)` call has completed. - * - * ### IMPORTANT: Document Lifetime - * - * After the call to parse_into_document, the parser is no longer needed. - * - * The JSON document lives in the document instance: you must keep the - * document - * instance alive while you navigate through it (i.e., used the returned - * value from - * parse_into_document). You are encourage to reuse the document instance - * many times with new data to avoid reallocations: - * - * dom::document doc; - * element doc_root1 = parser.parse_into_document(doc, buf1, len); - * //... doc_root1 is a pointer inside doc - * element doc_root2 = parser.parse_into_document(doc, buf1, len); - * //... doc_root2 is a pointer inside doc - * // at this point doc_root1 is no longer safe - * - * Moving the document instance is safe, but it invalidates the element - * instances. After - * moving a document, you can recover safe access to the document root with - * its `root()` method. - * - * @param doc The document instance where the parsed data will be stored (on - * success). - * @param buf The JSON to parse. Must have at least len + SIMDJSON_PADDING - * allocated bytes, unless - * realloc_if_needed is true. - * @param len The length of the JSON. - * @param realloc_if_needed Whether to reallocate and enlarge the JSON - * buffer to add padding. 
- * @return An element pointing at the root of document, or an error: - * - MEMALLOC if realloc_if_needed is true or the parser does not - * have enough capacity, - * and memory allocation fails. - * - CAPACITY if the parser does not have enough capacity and len > - * max_capacity. - * - other json errors if parsing fails. You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result parse_into_document( - document &doc, - const uint8_t *buf, - size_t len, - bool realloc_if_needed = true) & - noexcept; - inline simdjson_result parse_into_document( - document &doc, - const uint8_t *buf, - size_t len, - bool realloc_if_needed = true) && = delete; - /** @overload parse_into_document(const uint8_t *buf, size_t len, bool - * realloc_if_needed) */ - simdjson_really_inline simdjson_result parse_into_document( - document &doc, - const char *buf, - size_t len, - bool realloc_if_needed = true) & - noexcept; - simdjson_really_inline simdjson_result parse_into_document( - document &doc, - const char *buf, - size_t len, - bool realloc_if_needed = true) && = delete; - /** @overload parse_into_document(const uint8_t *buf, size_t len, bool - * realloc_if_needed) */ - simdjson_really_inline simdjson_result parse_into_document( - document &doc, const std::string &s) & - noexcept; - simdjson_really_inline simdjson_result parse_into_document( - document &doc, const std::string &s) && = delete; - /** @overload parse_into_document(const uint8_t *buf, size_t len, bool - * realloc_if_needed) */ - simdjson_really_inline simdjson_result parse_into_document( - document &doc, const padded_string &s) & - noexcept; - simdjson_really_inline simdjson_result parse_into_document( - document &doc, const padded_string &s) && = delete; - - /** @private We do not want to allow implicit conversion from C string to - * std::string. 
*/ - simdjson_really_inline simdjson_result parse_into_document( - document &doc, const char *buf) noexcept = delete; - - /** - * Load a file containing many JSON documents. - * - * dom::parser parser; - * for (const element doc : parser.load_many(path)) { - * cout << std::string(doc["title"]) << endl; - * } - * - * The file is loaded in memory and can be safely deleted after the - * `parser.load_many(path)` - * function has returned. The memory is held by the `parser` instance. - * - * The function is lazy: it may be that no more than one JSON document at a - * time is parsed. - * And, possibly, no document many have been parsed when the - * `parser.load_many(path)` function - * returned. - * - * ### Format - * - * The file must contain a series of one or more JSON documents, - * concatenated into a single - * buffer, separated by whitespace. It effectively parses until it has a - * fully valid document, - * then starts parsing the next document at that point. (It does this with - * more parallelism and - * lookahead than you might think, though.) - * - * Documents that consist of an object or array may omit the whitespace - * between them, concatenating - * with no separator. documents that consist of a single primitive (i.e. - * documents that are not - * arrays or objects) MUST be separated with whitespace. - * - * The documents must not exceed batch_size bytes (by default 1MB) or they - * will fail to parse. - * Setting batch_size to excessively large or excesively small values may - * impact negatively the - * performance. - * - * ### Error Handling - * - * All errors are returned during iteration: if there is a global error such - * as memory allocation, - * it will be yielded as the first result. Iteration always stops after the - * first error. 
- * - * As with all other simdjson methods, non-exception error handling is - * readily available through - * the same interface, requiring you to check the error before using the - * document: - * - * dom::parser parser; - * dom::document_stream docs; - * auto error = parser.load_many(path).get(docs); - * if (error) { cerr << error << endl; exit(1); } - * for (auto doc : docs) { - * std::string_view title; - * if ((error = doc["title"].get(title)) { cerr << error << endl; - * exit(1); } - * cout << title << endl; - * } - * - * ### Threads - * - * When compiled with SIMDJSON_THREADS_ENABLED, this method will use a - * single thread under the - * hood to do some lookahead. - * - * ### Parser Capacity - * - * If the parser's current capacity is less than batch_size, it will - * allocate enough capacity - * to handle it (up to max_capacity). - * - * @param path File name pointing at the concatenated JSON to parse. - * @param batch_size The batch size to use. MUST be larger than the largest - * document. The sweet - * spot is cache-related: small enough to fit in cache, - * yet big enough to - * parse as many documents as possible in one tight loop. - * Defaults to 1MB (as simdjson::dom::DEFAULT_BATCH_SIZE), - * which has been a reasonable sweet - * spot in our tests. - * If you set the batch_size to a value smaller than - * simdjson::dom::MINIMAL_BATCH_SIZE - * (currently 32B), it will be replaced by - * simdjson::dom::MINIMAL_BATCH_SIZE. - * @return The stream, or an error. An empty input will yield 0 documents - * rather than an EMPTY error. Errors: - * - IO_ERROR if there was an error opening or reading the file. - * - MEMALLOC if the parser does not have enough capacity and memory - * allocation fails. - * - CAPACITY if the parser does not have enough capacity and - * batch_size > max_capacity. - * - other json errors if parsing fails. 
You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result load_many( - const std::string &path, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept; - - /** - * Parse a buffer containing many JSON documents. - * - * dom::parser parser; - * for (element doc : parser.parse_many(buf, len)) { - * cout << std::string(doc["title"]) << endl; - * } - * - * No copy of the input buffer is made. - * - * The function is lazy: it may be that no more than one JSON document at a - * time is parsed. - * And, possibly, no document many have been parsed when the - * `parser.load_many(path)` function - * returned. - * - * The caller is responsabile to ensure that the input string data remains - * unchanged and is - * not deleted during the loop. In particular, the following is unsafe and - * will not compile: - * - * auto docs = parser.parse_many("[\"temporary data\"]"_padded); - * // here the string "[\"temporary data\"]" may no longer exist in memory - * // the parser instance may not have even accessed the input yet - * for (element doc : docs) { - * cout << std::string(doc["title"]) << endl; - * } - * - * The following is safe: - * - * auto json = "[\"temporary data\"]"_padded; - * auto docs = parser.parse_many(json); - * for (element doc : docs) { - * cout << std::string(doc["title"]) << endl; - * } - * - * ### Format - * - * The buffer must contain a series of one or more JSON documents, - * concatenated into a single - * buffer, separated by whitespace. It effectively parses until it has a - * fully valid document, - * then starts parsing the next document at that point. (It does this with - * more parallelism and - * lookahead than you might think, though.) - * - * documents that consist of an object or array may omit the whitespace - * between them, concatenating - * with no separator. 
documents that consist of a single primitive (i.e. - * documents that are not - * arrays or objects) MUST be separated with whitespace. - * - * The documents must not exceed batch_size bytes (by default 1MB) or they - * will fail to parse. - * Setting batch_size to excessively large or excesively small values may - * impact negatively the - * performance. - * - * ### Error Handling - * - * All errors are returned during iteration: if there is a global error such - * as memory allocation, - * it will be yielded as the first result. Iteration always stops after the - * first error. - * - * As with all other simdjson methods, non-exception error handling is - * readily available through - * the same interface, requiring you to check the error before using the - * document: - * - * dom::parser parser; - * dom::document_stream docs; - * auto error = parser.load_many(path).get(docs); - * if (error) { cerr << error << endl; exit(1); } - * for (auto doc : docs) { - * std::string_view title; - * if ((error = doc["title"].get(title)) { cerr << error << endl; - * exit(1); } - * cout << title << endl; - * } - * - * ### REQUIRED: Buffer Padding - * - * The buffer must have at least SIMDJSON_PADDING extra allocated bytes. It - * does not matter what - * those bytes are initialized to, as long as they are allocated. - * - * ### Threads - * - * When compiled with SIMDJSON_THREADS_ENABLED, this method will use a - * single thread under the - * hood to do some lookahead. - * - * ### Parser Capacity - * - * If the parser's current capacity is less than batch_size, it will - * allocate enough capacity - * to handle it (up to max_capacity). - * - * @param buf The concatenated JSON to parse. Must have at least len + - * SIMDJSON_PADDING allocated bytes. - * @param len The length of the concatenated JSON. - * @param batch_size The batch size to use. MUST be larger than the largest - * document. 
The sweet - * spot is cache-related: small enough to fit in cache, - * yet big enough to - * parse as many documents as possible in one tight loop. - * Defaults to 10MB, which has been a reasonable sweet - * spot in our tests. - * @return The stream, or an error. An empty input will yield 0 documents - * rather than an EMPTY error. Errors: - * - MEMALLOC if the parser does not have enough capacity and memory - * allocation fails - * - CAPACITY if the parser does not have enough capacity and - * batch_size > max_capacity. - * - other json errors if parsing fails. You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result parse_many( - const uint8_t *buf, - size_t len, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept; - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result parse_many( - const char *buf, - size_t len, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept; - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result parse_many( - const std::string &s, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept; - inline simdjson_result parse_many( - const std::string &&s, size_t batch_size) = delete; // unsafe - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result parse_many( - const padded_string &s, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept; - inline simdjson_result parse_many( - const padded_string &&s, size_t batch_size) = delete; // unsafe - - /** @private We do not want to allow implicit conversion from C string to - * std::string. 
*/ - simdjson_result parse_many( - const char *buf, - size_t batch_size = dom::DEFAULT_BATCH_SIZE) noexcept = delete; - - /** - * Ensure this parser has enough memory to process JSON documents up to - * `capacity` bytes in length - * and `max_depth` depth. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. Defaults to DEFAULT_MAX_DEPTH. - * @return The error, if there is one. - */ - simdjson_warn_unused inline error_code allocate( - size_t capacity, size_t max_depth = DEFAULT_MAX_DEPTH) noexcept; - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API - /** - * @private deprecated because it returns bool instead of error_code, which - * is our standard for - * failures. Use allocate() instead. - * - * Ensure this parser has enough memory to process JSON documents up to - * `capacity` bytes in length - * and `max_depth` depth. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. Defaults to DEFAULT_MAX_DEPTH. - * @return true if successful, false if allocation failed. - */ - [[deprecated("Use allocate() instead.")]] simdjson_warn_unused inline bool - allocate_capacity(size_t capacity, - size_t max_depth = DEFAULT_MAX_DEPTH) noexcept; -#endif // SIMDJSON_DISABLE_DEPRECATED_API - /** - * The largest document this parser can support without reallocating. - * - * @return Current capacity, in bytes. - */ - simdjson_really_inline size_t capacity() const noexcept; - - /** - * The largest document this parser can automatically support. - * - * The parser may reallocate internal buffers as needed up to this amount. - * - * @return Maximum capacity, in bytes. - */ - simdjson_really_inline size_t max_capacity() const noexcept; - - /** - * The maximum level of nested object and arrays supported by this parser. - * - * @return Maximum depth, in bytes. - */ - simdjson_really_inline size_t max_depth() const noexcept; - - /** - * Set max_capacity. This is the largest document this parser can - * automatically support. 
- * - * The parser may reallocate internal buffers as needed up to this amount as - * documents are passed - * to it. - * - * Note: To avoid limiting the memory to an absurd value, such as zero or - * two bytes, - * iff you try to set max_capacity to a value lower than - * MINIMAL_DOCUMENT_CAPACITY, - * then the maximal capacity is set to MINIMAL_DOCUMENT_CAPACITY. - * - * This call will not allocate or deallocate, even if capacity is currently - * above max_capacity. - * - * @param max_capacity The new maximum capacity, in bytes. - */ - simdjson_really_inline void set_max_capacity(size_t max_capacity) noexcept; - -#ifdef SIMDJSON_THREADS_ENABLED - /** - * The parser instance can use threads when they are available to speed up - * some - * operations. It is enabled by default. Changing this attribute will change - * the - * behavior of the parser for future operations. - */ - bool threaded{true}; -#endif - /** @private Use the new DOM API instead */ - class Iterator; - /** @private Use simdjson_error instead */ - using InvalidJSON[[deprecated("Use simdjson_error instead")]] = - simdjson_error; - - /** @private [for benchmarking access] The implementation to use */ - std::unique_ptr implementation{}; - - /** @private Use `if (parser.parse(...).error())` instead */ - bool valid{false}; - /** @private Use `parser.parse(...).error()` instead */ - error_code error{UNINITIALIZED}; - - /** @private Use `parser.parse(...).value()` instead */ - document doc{}; - - /** @private returns true if the document parsed was valid */ - [[deprecated("Use the result of parser.parse() instead")]] inline bool - is_valid() const noexcept; - - /** - * @private return an error code corresponding to the last parsing attempt, - * see - * simdjson.h will return UNINITIALIZED if no parsing was attempted - */ - [[deprecated("Use the result of parser.parse() instead")]] inline int - get_error_code() const noexcept; - - /** @private return the string equivalent of "get_error_code" */ - [ - 
[deprecated("Use error_message() on the result of parser.parse() " - "instead, or cout << error")]] inline std::string - get_error_message() const noexcept; - - /** @private */ - [[deprecated( - "Use cout << on the result of parser.parse() instead")]] inline bool - print_json(std::ostream &os) const noexcept; - - /** @private Private and deprecated: use - * `parser.parse(...).doc.dump_raw_tape()` instead */ - inline bool dump_raw_tape(std::ostream &os) const noexcept; - - - private: - /** - * The maximum document length this parser will automatically support. - * - * The parser will not be automatically allocated above this amount. - */ - size_t _max_capacity; - - /** - * The loaded buffer (reused each time load() is called) - */ - std::unique_ptr loaded_bytes; - - /** Capacity of loaded_bytes buffer. */ - size_t _loaded_bytes_capacity{0}; - - // all nodes are stored on the doc.tape using a 64-bit word. - // - // strings, double and ints are stored as - // a 64-bit word with a pointer to the actual value - // - // - // - // for objects or arrays, store [ or { at the beginning and } and ] at the - // end. For the openings ([ or {), we annotate them with a reference to the - // location on the doc.tape of the end, and for then closings (} and ]), we - // annotate them with a reference to the location of the opening - // - // - - /** - * Ensure we have enough capacity to handle at least desired_capacity bytes, - * and auto-allocate if not. This also allocates memory if needed in the - * internal document. - */ - inline error_code ensure_capacity(size_t desired_capacity) noexcept; - /** - * Ensure we have enough capacity to handle at least desired_capacity bytes, - * and auto-allocate if not. This also allocates memory if needed in the - * provided document. 
- */ - inline error_code ensure_capacity(document &doc, - size_t desired_capacity) noexcept; - - /** Read the file into loaded_bytes */ - inline simdjson_result read_file(const std::string &path) noexcept; - - friend class parser::Iterator; - friend class document_stream; - - -}; // class parser - -} // namespace dom -} // namespace simdjson - -#endif // SIMDJSON_DOM_PARSER_H -/* end file include/simdjson/dom/parser.h */ -#ifdef SIMDJSON_THREADS_ENABLED -#include -#include -#include -#endif - -namespace simdjson { -namespace dom { - - -#ifdef SIMDJSON_THREADS_ENABLED -/** @private Custom worker class **/ -struct stage1_worker { - stage1_worker() noexcept = default; - stage1_worker(const stage1_worker &) = delete; - stage1_worker(stage1_worker &&) = delete; - stage1_worker operator=(const stage1_worker &) = delete; - ~stage1_worker(); - /** - * We only start the thread when it is needed, not at object construction, - *this may throw. - * You should only call this once. - **/ - void start_thread(); - /** - * Start a stage 1 job. You should first call 'run', then 'finish'. - * You must call start_thread once before. - */ - void run(document_stream *ds, dom::parser *stage1, size_t next_batch_start); - /** Wait for the run to finish (blocking). You should first call 'run', then - * 'finish'. **/ - void finish(); - - private: - /** - * Normally, we would never stop the thread. But we do in the destructor. - * This function is only safe assuming that you are not waiting for results. - *You - * should have called run, then finish, and be done. - **/ - void stop_thread(); - - std::thread thread{}; - /** These three variables define the work done by the thread. **/ - dom::parser *stage1_thread_parser{}; - size_t _next_batch_start{}; - document_stream *owner{}; - /** - * We have two state variables. This could be streamlined to one variable in - * the future but - * we use two for clarity. - */ - bool has_work{false}; - bool can_work{true}; - - /** - * We lock using a mutex. 
- */ - std::mutex locking_mutex{}; - std::condition_variable cond_var{}; -}; -#endif - -/** - * A forward-only stream of documents. - * - * Produced by parser::parse_many. - * - */ -class document_stream { - public: - /** - * Construct an uninitialized document_stream. - * - * ```c++ - * document_stream docs; - * error = parser.parse_many(json).get(docs); - * ``` - */ - simdjson_really_inline document_stream() noexcept; - /** Move one document_stream to another. */ - simdjson_really_inline document_stream(document_stream &&other) noexcept = - default; - /** Move one document_stream to another. */ - simdjson_really_inline document_stream &operator=( - document_stream &&other) noexcept = default; - - simdjson_really_inline ~document_stream() noexcept; - /** - * Returns the input size in bytes. - */ - inline size_t size_in_bytes() const noexcept; - /** - * After iterating through the stream, this method - * returns the number of bytes that were not parsed at the end - * of the stream. If truncated_bytes() differs from zero, - * then the input was truncated maybe because incomplete JSON - * documents were found at the end of the stream. You - * may need to process the bytes in the interval - * [size_in_bytes()-truncated_bytes(), size_in_bytes()). - * - * You should only call truncated_bytes() after streaming through all - * documents, like so: - * - * document_stream stream = parser.parse_many(json,window); - * for(auto doc : stream) { - * // do something with doc - * } - * size_t truncated = stream.truncated_bytes(); - * - */ - inline size_t truncated_bytes() const noexcept; - /** - * An iterator through a forward-only stream of documents. - */ - class iterator { - public: - using value_type = simdjson_result; - using reference = value_type; - - using difference_type = std::ptrdiff_t; - - using iterator_category = std::input_iterator_tag; - - /** - * Default constructor. - */ - simdjson_really_inline iterator() noexcept; - /** - * Get the current document (or error). 
- */ - simdjson_really_inline reference operator*() noexcept; - /** - * Advance to the next document (prefix). - */ - inline iterator &operator++() noexcept; - /** - * Check if we're at the end yet. - * @param other the end iterator to compare to. - */ - simdjson_really_inline bool operator!=(const iterator &other) const - noexcept; - /** - * @private - * - * Gives the current index in the input document in bytes. - * - * document_stream stream = parser.parse_many(json,window); - * for(auto i = stream.begin(); i != stream.end(); ++i) { - * auto doc = *i; - * size_t index = i.current_index(); - * } - * - * This function (current_index()) is experimental and the usage - * may change in future versions of simdjson: we find the API somewhat - * awkward and we would like to offer something friendlier. - */ - simdjson_really_inline size_t current_index() const noexcept; - /** - * @private - * - * Gives a view of the current document. - * - * document_stream stream = parser.parse_many(json,window); - * for(auto i = stream.begin(); i != stream.end(); ++i) { - * auto doc = *i; - * std::string_view v = i->source(); - * } - * - * The returned string_view instance is simply a map to the (unparsed) - * source string: it may thus include white-space characters and all - * manner - * of padding. - * - * This function (source()) is experimental and the usage - * may change in future versions of simdjson: we find the API somewhat - * awkward and we would like to offer something friendlier. - */ - simdjson_really_inline std::string_view source() const noexcept; - - private: - simdjson_really_inline iterator(document_stream *s, - bool finished) noexcept; - /** The document_stream we're iterating through. */ - document_stream *stream; - /** Whether we're finished or not. */ - bool finished; - friend class document_stream; - }; - - /** - * Start iterating the documents in the stream. 
- */ - simdjson_really_inline iterator begin() noexcept; - /** - * The end of the stream, for iterator comparison purposes. - */ - simdjson_really_inline iterator end() noexcept; - - private: - document_stream &operator=(const document_stream &) = - delete; // Disallow copying - document_stream(const document_stream &other) = delete; // Disallow copying - - /** - * Construct a document_stream. Does not allocate or parse anything until - * the iterator is - * used. - * - * @param parser is a reference to the parser instance used to generate this - * document_stream - * @param buf is the raw byte buffer we need to process - * @param len is the length of the raw byte buffer in bytes - * @param batch_size is the size of the windows (must be strictly greater or - * equal to the largest JSON document) - */ - simdjson_really_inline document_stream(dom::parser &parser, - const uint8_t *buf, - size_t len, - size_t batch_size) noexcept; - - /** - * Parse the first document in the buffer. Used by begin(), to handle - * allocation and - * initialization. - */ - inline void start() noexcept; - - /** - * Parse the next document found in the buffer previously given to - * document_stream. - * - * The content should be a valid JSON document encoded as UTF-8. If there is - * a - * UTF-8 BOM, the caller is responsible for omitting it, UTF-8 BOM are - * discouraged. - * - * You do NOT need to pre-allocate a parser. This function takes care of - * pre-allocating a capacity defined by the batch_size defined when creating - * the - * document_stream object. - * - * The function returns simdjson::EMPTY if there is no more data to be - * parsed. - * - * The function returns simdjson::SUCCESS (as integer = 0) in case of - * success - * and indicates that the buffer has successfully been parsed to the end. - * Every document it contained has been parsed without error. 
- * - * The function returns an error code from simdjson/simdjson.h in case of - * failure - * such as simdjson::CAPACITY, simdjson::MEMALLOC, simdjson::DEPTH_ERROR and - * so forth; - * the simdjson::error_message function converts these error codes into a - * string). - * - * You can also check validity by calling parser.is_valid(). The same parser - * can - * and should be reused for the other documents in the buffer. - */ - inline void next() noexcept; - - /** - * Pass the next batch through stage 1 and return when finished. - * When threads are enabled, this may wait for the stage 1 thread to finish. - */ - inline void load_batch() noexcept; - - /** Get the next document index. */ - inline size_t next_batch_start() const noexcept; - - /** Pass the next batch through stage 1 with the given parser. */ - inline error_code run_stage1(dom::parser &p, size_t batch_start) noexcept; - - dom::parser *parser; - const uint8_t *buf; - size_t len; - size_t batch_size; - /** The error (or lack thereof) from the current document. */ - error_code error; - size_t batch_start{0}; - size_t doc_index{}; -#ifdef SIMDJSON_THREADS_ENABLED - /** Indicates whether we use threads. Note that this needs to be a constant - * during the execution of the parsing. */ - bool use_thread; - - inline void load_from_stage1_thread() noexcept; - - /** Start a thread to run stage 1 on the next batch. */ - inline void start_stage1_thread() noexcept; - - /** Wait for the stage 1 thread to finish and capture the results. */ - inline void finish_stage1_thread() noexcept; - - /** The error returned from the stage 1 thread. */ - error_code stage1_thread_error{UNINITIALIZED}; - /** The thread used to run stage 1 against the next batch in the background. - */ - friend struct stage1_worker; - std::unique_ptr worker{new (std::nothrow) stage1_worker()}; - /** - * The parser used to run stage 1 in the background. Will be swapped - * with the regular parser when finished. 
- */ - dom::parser stage1_thread_parser{}; -#endif // SIMDJSON_THREADS_ENABLED - - friend class dom::parser; - friend struct simdjson_result; - friend struct internal::simdjson_result_base; - -}; // class document_stream - -} // namespace dom - -template <> -struct simdjson_result - : public internal::simdjson_result_base { - public: - simdjson_really_inline simdjson_result() noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result( - dom::document_stream &&value) noexcept; ///< @private - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline dom::document_stream::iterator begin() noexcept( - false); - simdjson_really_inline dom::document_stream::iterator end() noexcept(false); -#else // SIMDJSON_EXCEPTIONS -#ifndef SIMDJSON_DISABLE_DEPRECATED_API - [[deprecated( - "parse_many() and load_many() may return errors. Use document_stream " - "stream; error = parser.parse_many().get(doc); " - "instead.")]] simdjson_really_inline dom::document_stream::iterator - begin() noexcept; - [[deprecated( - "parse_many() and load_many() may return errors. Use document_stream " - "stream; error = parser.parse_many().get(doc); " - "instead.")]] simdjson_really_inline dom::document_stream::iterator - end() noexcept; -#endif // SIMDJSON_DISABLE_DEPRECATED_API -#endif // SIMDJSON_EXCEPTIONS -}; // struct simdjson_result - -} // namespace simdjson - -#endif // SIMDJSON_DOCUMENT_STREAM_H -/* end file include/simdjson/dom/document_stream.h */ -/* begin file include/simdjson/dom/element.h */ -#ifndef SIMDJSON_DOM_ELEMENT_H -#define SIMDJSON_DOM_ELEMENT_H - -#include - -namespace simdjson { -namespace internal { -template -class string_builder; -} -namespace dom { -class array; -class document; -class object; - -/** - * The actual concrete type of a JSON element - * This is the type it is most easily cast to with get<>. 
- */ -enum class element_type { - ARRAY = '[', ///< dom::array - OBJECT = '{', ///< dom::object - INT64 = 'l', ///< int64_t - UINT64 = - 'u', ///< uint64_t: any integer that fits in uint64_t but *not* int64_t - DOUBLE = - 'd', ///< double: Any number with a "." or "e" that fits in double. - STRING = '"', ///< std::string_view - BOOL = 't', ///< bool - NULL_VALUE = 'n' ///< null -}; - -/** - * A JSON element. - * - * References an element in a JSON document, representing a JSON null, boolean, - * string, number, - * array or object. - */ -class element { - public: - /** Create a new, invalid element. */ - simdjson_really_inline element() noexcept; - - /** The type of this element. */ - simdjson_really_inline element_type type() const noexcept; - - /** - * Cast this element to an array. - * - * @returns An object that can be used to iterate the array, or: - * INCORRECT_TYPE if the JSON element is not an array. - */ - inline simdjson_result get_array() const noexcept; - /** - * Cast this element to an object. - * - * @returns An object that can be used to look up or iterate the object's - * fields, or: - * INCORRECT_TYPE if the JSON element is not an object. - */ - inline simdjson_result get_object() const noexcept; - /** - * Cast this element to a null-terminated C string. - * - * The string is guaranteed to be valid UTF-8. - * - * The length of the string is given by get_string_length(). Because JSON - * strings - * may contain null characters, it may be incorrect to use strlen to - * determine the - * string length. - * - * It is possible to get a single string_view instance which represents both - * the string - * content and its length: see get_string(). - * - * @returns A pointer to a null-terminated UTF-8 string. This string is - * stored in the parser and will - * be invalidated the next time it parses a document or when it is - * destroyed. - * Returns INCORRECT_TYPE if the JSON element is not a string. 
- */ - inline simdjson_result get_c_str() const noexcept; - /** - * Gives the length in bytes of the string. - * - * It is possible to get a single string_view instance which represents both - * the string - * content and its length: see get_string(). - * - * @returns A string length in bytes. - * Returns INCORRECT_TYPE if the JSON element is not a string. - */ - inline simdjson_result get_string_length() const noexcept; - /** - * Cast this element to a string. - * - * The string is guaranteed to be valid UTF-8. - * - * @returns An UTF-8 string. The string is stored in the parser and will be - * invalidated the next time it - * parses a document or when it is destroyed. - * Returns INCORRECT_TYPE if the JSON element is not a string. - */ - inline simdjson_result get_string() const noexcept; - /** - * Cast this element to a signed integer. - * - * @returns A signed 64-bit integer. - * Returns INCORRECT_TYPE if the JSON element is not an integer, or - * NUMBER_OUT_OF_RANGE - * if it is negative. - */ - inline simdjson_result get_int64() const noexcept; - /** - * Cast this element to an unsigned integer. - * - * @returns An unsigned 64-bit integer. - * Returns INCORRECT_TYPE if the JSON element is not an integer, or - * NUMBER_OUT_OF_RANGE - * if it is too large. - */ - inline simdjson_result get_uint64() const noexcept; - /** - * Cast this element to a double floating-point. - * - * @returns A double value. - * Returns INCORRECT_TYPE if the JSON element is not a number. - */ - inline simdjson_result get_double() const noexcept; - /** - * Cast this element to a bool. - * - * @returns A bool value. - * Returns INCORRECT_TYPE if the JSON element is not a boolean. - */ - inline simdjson_result get_bool() const noexcept; - - /** - * Whether this element is a json array. - * - * Equivalent to is(). - */ - inline bool is_array() const noexcept; - /** - * Whether this element is a json object. - * - * Equivalent to is(). 
- */ - inline bool is_object() const noexcept; - /** - * Whether this element is a json string. - * - * Equivalent to is() or is(). - */ - inline bool is_string() const noexcept; - /** - * Whether this element is a json number that fits in a signed 64-bit - * integer. - * - * Equivalent to is(). - */ - inline bool is_int64() const noexcept; - /** - * Whether this element is a json number that fits in an unsigned 64-bit - * integer. - * - * Equivalent to is(). - */ - inline bool is_uint64() const noexcept; - /** - * Whether this element is a json number that fits in a double. - * - * Equivalent to is(). - */ - inline bool is_double() const noexcept; - - /** - * Whether this element is a json number. - * - * Both integers and floating points will return true. - */ - inline bool is_number() const noexcept; - - /** - * Whether this element is a json `true` or `false`. - * - * Equivalent to is(). - */ - inline bool is_bool() const noexcept; - /** - * Whether this element is a json `null`. - */ - inline bool is_null() const noexcept; - - /** - * Tell whether the value can be cast to provided type (T). - * - * Supported types: - * - Boolean: bool - * - Number: double, uint64_t, int64_t - * - String: std::string_view, const char * - * - Array: dom::array - * - Object: dom::object - * - * @tparam T bool, double, uint64_t, int64_t, std::string_view, const char - * *, dom::array, dom::object - */ - template - simdjson_really_inline bool is() const noexcept; - - /** - * Get the value as the provided type (T). - * - * Supported types: - * - Boolean: bool - * - Number: double, uint64_t, int64_t - * - String: std::string_view, const char * - * - Array: dom::array - * - Object: dom::object - * - * You may use get_double(), get_bool(), get_uint64(), get_int64(), - * get_object(), get_array() or get_string() instead. 
- * - * @tparam T bool, double, uint64_t, int64_t, std::string_view, const char - * *, dom::array, dom::object - * - * @returns The value cast to the given type, or: - * INCORRECT_TYPE if the value cannot be cast to the given type. - */ - - template - inline simdjson_result get() const noexcept { - // Unless the simdjson library provides an inline implementation, - // calling this method should - // immediately fail. - static_assert(!sizeof(T), - "The get method with given type is not implemented by " - "the simdjson library."); - } - - /** - * Get the value as the provided type (T). - * - * Supported types: - * - Boolean: bool - * - Number: double, uint64_t, int64_t - * - String: std::string_view, const char * - * - Array: dom::array - * - Object: dom::object - * - * @tparam T bool, double, uint64_t, int64_t, std::string_view, const char - * *, dom::array, dom::object - * - * @param value The variable to set to the value. May not be set if there is - * an error. - * - * @returns The error that occurred, or SUCCESS if there was no error. - */ - template - simdjson_warn_unused simdjson_really_inline error_code get(T &value) const - noexcept; - - /** - * Get the value as the provided type (T), setting error if it's not the - * given type. - * - * Supported types: - * - Boolean: bool - * - Number: double, uint64_t, int64_t - * - String: std::string_view, const char * - * - Array: dom::array - * - Object: dom::object - * - * @tparam T bool, double, uint64_t, int64_t, std::string_view, const char - * *, dom::array, dom::object - * - * @param value The variable to set to the given type. value is undefined if - * there is an error. - * @param error The variable to store the error. error is set to - * error_code::SUCCEED if there is an error. - */ - template - inline void tie(T &value, error_code &error) && noexcept; - -#if SIMDJSON_EXCEPTIONS - /** - * Read this element as a boolean. 
- * - * @return The boolean value - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not a - * boolean. - */ - inline operator bool() const noexcept(false); - - /** - * Read this element as a null-terminated UTF-8 string. - * - * Be mindful that JSON allows strings to contain null characters. - * - * Does *not* convert other types to a string; requires that the JSON type - * of the element was - * an actual string. - * - * @return The string value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not a - * string. - */ - inline explicit operator const char *() const noexcept(false); - - /** - * Read this element as a null-terminated UTF-8 string. - * - * Does *not* convert other types to a string; requires that the JSON type - * of the element was - * an actual string. - * - * @return The string value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not a - * string. - */ - inline operator std::string_view() const noexcept(false); - - /** - * Read this element as an unsigned integer. - * - * @return The integer value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * integer - * @exception simdjson_error(NUMBER_OUT_OF_RANGE) if the integer doesn't fit - * in 64 bits or is negative - */ - inline operator uint64_t() const noexcept(false); - /** - * Read this element as an signed integer. - * - * @return The integer value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * integer - * @exception simdjson_error(NUMBER_OUT_OF_RANGE) if the integer doesn't fit - * in 64 bits - */ - inline operator int64_t() const noexcept(false); - /** - * Read this element as an double. - * - * @return The double value. 
- * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not a - * number - * @exception simdjson_error(NUMBER_OUT_OF_RANGE) if the integer doesn't fit - * in 64 bits or is negative - */ - inline operator double() const noexcept(false); - /** - * Read this element as a JSON array. - * - * @return The JSON array. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * array - */ - inline operator array() const noexcept(false); - /** - * Read this element as a JSON object (key/value pairs). - * - * @return The JSON object. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * object - */ - inline operator object() const noexcept(false); - - /** - * Iterate over each element in this array. - * - * @return The beginning of the iteration. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * array - */ - inline dom::array::iterator begin() const noexcept(false); - - /** - * Iterate over each element in this array. - * - * @return The end of the iteration. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON element is not an - * array - */ - inline dom::array::iterator end() const noexcept(false); -#endif // SIMDJSON_EXCEPTIONS - - /** - * Get the value associated with the given key. - * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - * - INCORRECT_TYPE if this is not an object - */ - inline simdjson_result operator[](std::string_view key) const - noexcept; - - /** - * Get the value associated with the given key. 
- * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - * - INCORRECT_TYPE if this is not an object - */ - inline simdjson_result operator[](const char *key) const noexcept; - - /** - * Get the value associated with the given JSON pointer. We use the RFC - * 6901 - * https://tools.ietf.org/html/rfc6901 standard. - * - * dom::parser parser; - * element doc = parser.parse(R"({ "foo": { "a": [ 10, 20, 30 ] - * }})"_padded); - * doc.at_pointer("/foo/a/1") == 20 - * doc.at_pointer("/foo")["a"].at(1) == 20 - * doc.at_pointer("")["foo"]["a"].at(1) == 20 - * - * It is allowed for a key to be the empty string: - * - * dom::parser parser; - * object obj = parser.parse(R"({ "": { "a": [ 10, 20, 30 ] }})"_padded); - * obj.at_pointer("//a/1") == 20 - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - inline simdjson_result at_pointer( - const std::string_view json_pointer) const noexcept; - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API - /** - * - * Version 0.4 of simdjson used an incorrect interpretation of the JSON - * Pointer standard - * and allowed the following : - * - * dom::parser parser; - * element doc = parser.parse(R"({ "foo": { "a": [ 10, 20, 30 ] - * }})"_padded); - * doc.at("foo/a/1") == 20 - * - * Though it is intuitive, it is not compliant with RFC 6901 - * https://tools.ietf.org/html/rfc6901 - * - * For standard compliance, use the at_pointer function 
instead. - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - [ - [deprecated("For standard compliance, use at_pointer instead, and " - "prefix your pointers with a slash '/', see " - "RFC6901 ")]] inline simdjson_result - at(const std::string_view json_pointer) const noexcept; -#endif // SIMDJSON_DISABLE_DEPRECATED_API - - /** - * Get the value at the given index. - * - * @return The value at the given index, or: - * - INDEX_OUT_OF_BOUNDS if the array index is larger than an array - * length - */ - inline simdjson_result at(size_t index) const noexcept; - - /** - * Get the value associated with the given key. - * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - */ - inline simdjson_result at_key(std::string_view key) const noexcept; - - /** - * Get the value associated with the given key in a case-insensitive manner. - * - * Note: The key will be matched against **unescaped** JSON. - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - */ - inline simdjson_result at_key_case_insensitive( - std::string_view key) const noexcept; - - /** @private for debugging. Prints out the root element. 
*/ - inline bool dump_raw_tape(std::ostream &out) const noexcept; - - private: - simdjson_really_inline element(const internal::tape_ref &tape) noexcept; - internal::tape_ref tape; - friend class document; - friend class object; - friend class array; - friend struct simdjson_result; - template - friend class simdjson::internal::string_builder; -}; - -} // namespace dom - -/** The result of a JSON navigation that may fail. */ -template <> -struct simdjson_result - : public internal::simdjson_result_base { - public: - simdjson_really_inline simdjson_result() noexcept; ///< @private - simdjson_really_inline simdjson_result( - dom::element &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - - simdjson_really_inline simdjson_result type() const - noexcept; - template - simdjson_really_inline bool is() const noexcept; - template - simdjson_really_inline simdjson_result get() const noexcept; - template - simdjson_warn_unused simdjson_really_inline error_code get(T &value) const - noexcept; - - simdjson_really_inline simdjson_result get_array() const - noexcept; - simdjson_really_inline simdjson_result get_object() const - noexcept; - simdjson_really_inline simdjson_result get_c_str() const - noexcept; - simdjson_really_inline simdjson_result get_string_length() const - noexcept; - simdjson_really_inline simdjson_result get_string() const - noexcept; - simdjson_really_inline simdjson_result get_int64() const noexcept; - simdjson_really_inline simdjson_result get_uint64() const - noexcept; - simdjson_really_inline simdjson_result get_double() const noexcept; - simdjson_really_inline simdjson_result get_bool() const noexcept; - - simdjson_really_inline bool is_array() const noexcept; - simdjson_really_inline bool is_object() const noexcept; - simdjson_really_inline bool is_string() const noexcept; - simdjson_really_inline bool is_int64() const noexcept; - simdjson_really_inline bool is_uint64() const 
noexcept; - simdjson_really_inline bool is_double() const noexcept; - simdjson_really_inline bool is_number() const noexcept; - simdjson_really_inline bool is_bool() const noexcept; - simdjson_really_inline bool is_null() const noexcept; - - simdjson_really_inline simdjson_result operator[]( - std::string_view key) const noexcept; - simdjson_really_inline simdjson_result operator[]( - const char *key) const noexcept; - simdjson_really_inline simdjson_result at_pointer( - const std::string_view json_pointer) const noexcept; - [[deprecated( - "For standard compliance, use at_pointer instead, and prefix your " - "pointers with a slash '/', see RFC6901 ")]] simdjson_really_inline - simdjson_result - at(const std::string_view json_pointer) const noexcept; - simdjson_really_inline simdjson_result at(size_t index) const - noexcept; - simdjson_really_inline simdjson_result at_key( - std::string_view key) const noexcept; - simdjson_really_inline simdjson_result - at_key_case_insensitive(std::string_view key) const noexcept; - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline operator bool() const noexcept(false); - simdjson_really_inline explicit operator const char *() const - noexcept(false); - simdjson_really_inline operator std::string_view() const noexcept(false); - simdjson_really_inline operator uint64_t() const noexcept(false); - simdjson_really_inline operator int64_t() const noexcept(false); - simdjson_really_inline operator double() const noexcept(false); - simdjson_really_inline operator dom::array() const noexcept(false); - simdjson_really_inline operator dom::object() const noexcept(false); - - simdjson_really_inline dom::array::iterator begin() const noexcept(false); - simdjson_really_inline dom::array::iterator end() const noexcept(false); -#endif // SIMDJSON_EXCEPTIONS -}; - - -} // namespace simdjson - -#endif // SIMDJSON_DOM_DOCUMENT_H -/* end file include/simdjson/dom/element.h */ -/* begin file include/simdjson/dom/object.h */ -#ifndef 
SIMDJSON_DOM_OBJECT_H -#define SIMDJSON_DOM_OBJECT_H - - -namespace simdjson { -namespace internal { -template -class string_builder; -} -namespace dom { - -class document; -class element; -class key_value_pair; - -/** - * JSON object. - */ -class object { - public: - /** Create a new, invalid object */ - simdjson_really_inline object() noexcept; - - class iterator { - public: - using value_type = key_value_pair; - using difference_type = std::ptrdiff_t; - - /** - * Get the actual key/value pair - */ - inline const value_type operator*() const noexcept; - /** - * Get the next key/value pair. - * - * Part of the std::iterator interface. - * - */ - inline iterator &operator++() noexcept; - /** - * Get the next key/value pair. - * - * Part of the std::iterator interface. - * - */ - inline iterator operator++(int)noexcept; - /** - * Check if these values come from the same place in the JSON. - * - * Part of the std::iterator interface. - */ - inline bool operator!=(const iterator &other) const noexcept; - inline bool operator==(const iterator &other) const noexcept; - - inline bool operator<(const iterator &other) const noexcept; - inline bool operator<=(const iterator &other) const noexcept; - inline bool operator>=(const iterator &other) const noexcept; - inline bool operator>(const iterator &other) const noexcept; - /** - * Get the key of this key/value pair. - */ - inline std::string_view key() const noexcept; - /** - * Get the length (in bytes) of the key in this key/value pair. - * You should expect this function to be faster than key().size(). - */ - inline uint32_t key_length() const noexcept; - /** - * Returns true if the key in this key/value pair is equal - * to the provided string_view. - */ - inline bool key_equals(std::string_view o) const noexcept; - /** - * Returns true if the key in this key/value pair is equal - * to the provided string_view in a case-insensitive manner. - * Case comparisons may only be handled correctly for ASCII strings. 
- */ - inline bool key_equals_case_insensitive(std::string_view o) const - noexcept; - /** - * Get the key of this key/value pair. - */ - inline const char *key_c_str() const noexcept; - /** - * Get the value of this key/value pair. - */ - inline element value() const noexcept; - - iterator() noexcept = default; - iterator(const iterator &) noexcept = default; - iterator &operator=(const iterator &) noexcept = default; - - private: - simdjson_really_inline iterator( - const internal::tape_ref &tape) noexcept; - - internal::tape_ref tape; - - friend class object; - }; - - /** - * Return the first key/value pair. - * - * Part of the std::iterable interface. - */ - inline iterator begin() const noexcept; - /** - * One past the last key/value pair. - * - * Part of the std::iterable interface. - */ - inline iterator end() const noexcept; - /** - * Get the size of the object (number of keys). - * It is a saturated value with a maximum of 0xFFFFFF: if the value - * is 0xFFFFFF then the size is 0xFFFFFF or greater. - */ - inline size_t size() const noexcept; - /** - * Get the value associated with the given key. - * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * This function has linear-time complexity: the keys are checked one by - * one. - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - * - INCORRECT_TYPE if this is not an object - */ - inline simdjson_result operator[](std::string_view key) const - noexcept; - - /** - * Get the value associated with the given key. 
- * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * This function has linear-time complexity: the keys are checked one by - * one. - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - * - INCORRECT_TYPE if this is not an object - */ - inline simdjson_result operator[](const char *key) const noexcept; - - /** - * Get the value associated with the given JSON pointer. We use the RFC 6901 - * https://tools.ietf.org/html/rfc6901 standard, interpreting the current - * node - * as the root of its own JSON document. - * - * dom::parser parser; - * object obj = parser.parse(R"({ "foo": { "a": [ 10, 20, 30 ] - * }})"_padded); - * obj.at_pointer("/foo/a/1") == 20 - * obj.at_pointer("/foo")["a"].at(1) == 20 - * - * It is allowed for a key to be the empty string: - * - * dom::parser parser; - * object obj = parser.parse(R"({ "": { "a": [ 10, 20, 30 ] }})"_padded); - * obj.at_pointer("//a/1") == 20 - * obj.at_pointer("/")["a"].at(1) == 20 - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - inline simdjson_result at_pointer( - std::string_view json_pointer) const noexcept; - - /** - * Get the value associated with the given key. 
- * - * The key will be matched against **unescaped** JSON: - * - * dom::parser parser; - * int64_t(parser.parse(R"({ "a\n": 1 })"_padded)["a\n"]) == 1 - * parser.parse(R"({ "a\n": 1 })"_padded)["a\\n"].get_uint64().error() == - * NO_SUCH_FIELD - * - * This function has linear-time complexity: the keys are checked one by - * one. - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - */ - inline simdjson_result at_key(std::string_view key) const noexcept; - - /** - * Get the value associated with the given key in a case-insensitive manner. - * It is only guaranteed to work over ASCII inputs. - * - * Note: The key will be matched against **unescaped** JSON. - * - * This function has linear-time complexity: the keys are checked one by - * one. - * - * @return The value associated with this field, or: - * - NO_SUCH_FIELD if the field does not exist in the object - */ - inline simdjson_result at_key_case_insensitive( - std::string_view key) const noexcept; - - private: - simdjson_really_inline object(const internal::tape_ref &tape) noexcept; - - internal::tape_ref tape; - - friend class element; - friend struct simdjson_result; - template - friend class simdjson::internal::string_builder; -}; - -/** - * Key/value pair in an object. - */ -class key_value_pair { - public: - /** key in the key-value pair **/ - std::string_view key; - /** value in the key-value pair **/ - element value; - - private: - simdjson_really_inline key_value_pair(std::string_view _key, - element _value) noexcept; - friend class object; -}; - -} // namespace dom - -/** The result of a JSON conversion that may fail. 
*/ -template <> -struct simdjson_result - : public internal::simdjson_result_base { - public: - simdjson_really_inline simdjson_result() noexcept; ///< @private - simdjson_really_inline simdjson_result( - dom::object value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - - inline simdjson_result operator[](std::string_view key) const - noexcept; - inline simdjson_result operator[](const char *key) const - noexcept; - inline simdjson_result at_pointer( - std::string_view json_pointer) const noexcept; - inline simdjson_result at_key(std::string_view key) const - noexcept; - inline simdjson_result at_key_case_insensitive( - std::string_view key) const noexcept; - -#if SIMDJSON_EXCEPTIONS - inline dom::object::iterator begin() const noexcept(false); - inline dom::object::iterator end() const noexcept(false); - inline size_t size() const noexcept(false); -#endif // SIMDJSON_EXCEPTIONS -}; - -} // namespace simdjson - -#if defined(__cpp_lib_ranges) -#include - -namespace std { -namespace ranges { -template <> -inline constexpr bool enable_view = true; -#if SIMDJSON_EXCEPTIONS -template <> -inline constexpr bool - enable_view> = true; -#endif // SIMDJSON_EXCEPTIONS -} // namespace ranges -} // namespace std -#endif // defined(__cpp_lib_ranges) - -#endif // SIMDJSON_DOM_OBJECT_H -/* end file include/simdjson/dom/object.h */ -/* begin file include/simdjson/dom/serialization.h */ -#ifndef SIMDJSON_SERIALIZATION_H -#define SIMDJSON_SERIALIZATION_H - -#include - -namespace simdjson { - -/** - * The string_builder template and mini_formatter class - * are not part of our public API and are subject to change - * at any time! - */ -namespace internal { - -class mini_formatter; - -/** - * @private The string_builder template allows us to construct - * a string from a document element. It is parametrized - * by a "formatter" which handles the details. 
Thus - * the string_builder template could support both minification - * and prettification, and various other tradeoffs. - */ -template -class string_builder { - public: - /** Construct an initially empty builder, would print the empty string **/ - string_builder() = default; - /** Append an element to the builder (to be printed) **/ - inline void append(simdjson::dom::element value); - /** Append an array to the builder (to be printed) **/ - inline void append(simdjson::dom::array value); - /** Append an object to the builder (to be printed) **/ - inline void append(simdjson::dom::object value); - /** Reset the builder (so that it would print the empty string) **/ - simdjson_really_inline void clear(); - /** - * Get access to the string. The string_view is owned by the builder - * and it is invalid to use it after the string_builder has been - * destroyed. - * However you can make a copy of the string_view on memory that you - * own. - */ - simdjson_really_inline std::string_view str() const; - /** Append a key_value_pair to the builder (to be printed) **/ - simdjson_really_inline void append(simdjson::dom::key_value_pair value); - - private: - formatter format{}; -}; - -/** - * @private This is the class that we expect to use with the string_builder - * template. It tries to produce a compact version of the JSON element - * as quickly as possible. 
- */ -class mini_formatter { - public: - mini_formatter() = default; - /** Add a comma **/ - simdjson_really_inline void comma(); - /** Start an array, prints [ **/ - simdjson_really_inline void start_array(); - /** End an array, prints ] **/ - simdjson_really_inline void end_array(); - /** Start an array, prints { **/ - simdjson_really_inline void start_object(); - /** Start an array, prints } **/ - simdjson_really_inline void end_object(); - /** Prints a true **/ - simdjson_really_inline void true_atom(); - /** Prints a false **/ - simdjson_really_inline void false_atom(); - /** Prints a null **/ - simdjson_really_inline void null_atom(); - /** Prints a number **/ - simdjson_really_inline void number(int64_t x); - /** Prints a number **/ - simdjson_really_inline void number(uint64_t x); - /** Prints a number **/ - simdjson_really_inline void number(double x); - /** Prints a key (string + colon) **/ - simdjson_really_inline void key(std::string_view unescaped); - /** Prints a string. The string is escaped as needed. **/ - simdjson_really_inline void string(std::string_view unescaped); - /** Clears out the content. **/ - simdjson_really_inline void clear(); - /** - * Get access to the buffer, it is owned by the instance, but - * the user can make a copy. - **/ - simdjson_really_inline std::string_view str() const; - - private: - // implementation details (subject to change) - /** Prints one character **/ - simdjson_really_inline void one_char(char c); - /** Backing buffer **/ - std::vector buffer{}; // not ideal! -}; - -} // internal - -namespace dom { - -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The element. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. 
- */ -inline std::ostream &operator<<(std::ostream &out, - simdjson::dom::element value) { - simdjson::internal::string_builder<> sb; - sb.append(value); - return (out << sb.str()); -} -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, simdjson::simdjson_result x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#endif -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The array. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. - */ -inline std::ostream &operator<<(std::ostream &out, simdjson::dom::array value) { - simdjson::internal::string_builder<> sb; - sb.append(value); - return (out << sb.str()); -} -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, simdjson::simdjson_result x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#endif -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The object. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. - */ -inline std::ostream &operator<<(std::ostream &out, - simdjson::dom::object value) { - simdjson::internal::string_builder<> sb; - sb.append(value); - return (out << sb.str()); -} -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, simdjson::simdjson_result x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#endif -} // namespace dom - -/** - * Converts JSON to a string. 
- * - * dom::parser parser; - * element doc = parser.parse(" [ 1 , 2 , 3 ] "_padded); - * cout << to_string(doc) << endl; // prints [1,2,3] - * - */ -template -std::string to_string(T x) { - // in C++, to_string is standard: - // http://www.cplusplus.com/reference/string/to_string/ - // Currently minify and to_string are identical but in the future, they may - // differ. - simdjson::internal::string_builder<> sb; - sb.append(x); - std::string_view answer = sb.str(); - return std::string(answer.data(), answer.size()); -} -#if SIMDJSON_EXCEPTIONS -template -std::string to_string(simdjson_result x) { - if (x.error()) { - throw simdjson_error(x.error()); - } - return to_string(x.value()); -} -#endif - -/** - * Minifies a JSON element or document, printing the smallest possible valid - * JSON. - * - * dom::parser parser; - * element doc = parser.parse(" [ 1 , 2 , 3 ] "_padded); - * cout << minify(doc) << endl; // prints [1,2,3] - * - */ -template -std::string minify(T x) { - return to_string(x); -} - -#if SIMDJSON_EXCEPTIONS -template -std::string minify(simdjson_result x) { - if (x.error()) { - throw simdjson_error(x.error()); - } - return to_string(x.value()); -} -#endif - - -} // namespace simdjson - - -#endif -/* end file include/simdjson/dom/serialization.h */ - -// Deprecated API -/* begin file include/simdjson/dom/jsonparser.h */ -// TODO Remove this -- deprecated API and files - -#ifndef SIMDJSON_DOM_JSONPARSER_H -#define SIMDJSON_DOM_JSONPARSER_H - -/* begin file include/simdjson/dom/parsedjson.h */ -// TODO Remove this -- deprecated API and files - -#ifndef SIMDJSON_DOM_PARSEDJSON_H -#define SIMDJSON_DOM_PARSEDJSON_H - - -namespace simdjson { - -/** - * @deprecated Use `dom::parser` instead. 
- */ -using ParsedJson[[deprecated("Use dom::parser instead")]] = dom::parser; - -} // namespace simdjson - -#endif // SIMDJSON_DOM_PARSEDJSON_H -/* end file include/simdjson/dom/parsedjson.h */ -/* begin file include/simdjson/jsonioutil.h */ -#ifndef SIMDJSON_JSONIOUTIL_H -#define SIMDJSON_JSONIOUTIL_H - - -namespace simdjson { - -#if SIMDJSON_EXCEPTIONS -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -[[deprecated("Use padded_string::load() instead")]] inline padded_string -get_corpus(const char *path) { - return padded_string::load(path); -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API -#endif // SIMDJSON_EXCEPTIONS - -} // namespace simdjson - -#endif // SIMDJSON_JSONIOUTIL_H -/* end file include/simdjson/jsonioutil.h */ - -namespace simdjson { - -// -// C API (json_parse and build_parsed_json) declarations -// - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -[[deprecated("Use parser.parse() instead")]] inline int json_parse( - const uint8_t *buf, - size_t len, - dom::parser &parser, - bool realloc_if_needed = true) noexcept { - error_code code = parser.parse(buf, len, realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return code; -} -[[deprecated("Use parser.parse() instead")]] inline int json_parse( - const char *buf, - size_t len, - dom::parser &parser, - bool realloc_if_needed = true) noexcept { - error_code code = parser.parse(buf, len, realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. 
The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return code; -} -[[deprecated("Use parser.parse() instead")]] inline int json_parse( - const std::string &s, - dom::parser &parser, - bool realloc_if_needed = true) noexcept { - error_code code = - parser.parse(s.data(), s.length(), realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return code; -} -[[deprecated("Use parser.parse() instead")]] inline int json_parse( - const padded_string &s, dom::parser &parser) noexcept { - error_code code = parser.parse(s).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return code; -} - -[[deprecated( - "Use parser.parse() instead")]] simdjson_warn_unused inline dom::parser -build_parsed_json(const uint8_t *buf, - size_t len, - bool realloc_if_needed = true) noexcept { - dom::parser parser; - error_code code = parser.parse(buf, len, realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. 
The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return parser; -} -[[deprecated( - "Use parser.parse() instead")]] simdjson_warn_unused inline dom::parser -build_parsed_json(const char *buf, - size_t len, - bool realloc_if_needed = true) noexcept { - dom::parser parser; - error_code code = parser.parse(buf, len, realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return parser; -} -[[deprecated( - "Use parser.parse() instead")]] simdjson_warn_unused inline dom::parser -build_parsed_json(const std::string &s, - bool realloc_if_needed = true) noexcept { - dom::parser parser; - error_code code = - parser.parse(s.data(), s.length(), realloc_if_needed).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. 
- parser.valid = code == SUCCESS; - parser.error = code; - return parser; -} -[[deprecated( - "Use parser.parse() instead")]] simdjson_warn_unused inline dom::parser -build_parsed_json(const padded_string &s) noexcept { - dom::parser parser; - error_code code = parser.parse(s).error(); - // The deprecated json_parse API is a signal that the user plans to *use* - // the error code / valid - // bits in the parser instead of heeding the result code. The normal parser - // unsets those in - // anticipation of making the error code ephemeral. - // Here we put the code back into the parser, until we've removed this - // method. - parser.valid = code == SUCCESS; - parser.error = code; - return parser; -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API - -/** @private We do not want to allow implicit conversion from C string to - * std::string. */ -int json_parse(const char *buf, dom::parser &parser) noexcept = delete; -/** @private We do not want to allow implicit conversion from C string to - * std::string. 
*/ -dom::parser build_parsed_json(const char *buf) noexcept = delete; - -} // namespace simdjson - -#endif // SIMDJSON_DOM_JSONPARSER_H -/* end file include/simdjson/dom/jsonparser.h */ -/* begin file include/simdjson/dom/parsedjson_iterator.h */ -// TODO Remove this -- deprecated API and files - -#ifndef SIMDJSON_DOM_PARSEDJSON_ITERATOR_H -#define SIMDJSON_DOM_PARSEDJSON_ITERATOR_H - -#include -#include -#include -#include -#include -#include - -/* begin file include/simdjson/internal/jsonformatutils.h */ -#ifndef SIMDJSON_INTERNAL_JSONFORMATUTILS_H -#define SIMDJSON_INTERNAL_JSONFORMATUTILS_H - -#include -#include -#include - -namespace simdjson { -namespace internal { - -class escape_json_string; - -inline std::ostream &operator<<(std::ostream &out, - const escape_json_string &str); - -class escape_json_string { - public: - escape_json_string(std::string_view _str) noexcept : str{_str} {} - operator std::string() const noexcept { - std::stringstream s; - s << *this; - return s.str(); - } - - private: - std::string_view str; - friend std::ostream &operator<<(std::ostream &out, - const escape_json_string &unescaped); -}; - -inline std::ostream &operator<<(std::ostream &out, - const escape_json_string &unescaped) { - for (size_t i = 0; i < unescaped.str.length(); i++) { - switch (unescaped.str[i]) { - case '\b': - out << "\\b"; - break; - case '\f': - out << "\\f"; - break; - case '\n': - out << "\\n"; - break; - case '\r': - out << "\\r"; - break; - case '\"': - out << "\\\""; - break; - case '\t': - out << "\\t"; - break; - case '\\': - out << "\\\\"; - break; - default: - if (static_cast(unescaped.str[i]) <= 0x1F) { - // TODO can this be done once at the beginning, or will it - // mess up << char? 
- std::ios::fmtflags f(out.flags()); - out << "\\u" << std::hex << std::setw(4) - << std::setfill('0') << int(unescaped.str[i]); - out.flags(f); - } else { - out << unescaped.str[i]; - } - } - } - return out; -} - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_JSONFORMATUTILS_H -/* end file include/simdjson/internal/jsonformatutils.h */ - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API - -namespace simdjson { -/** @private **/ -class[[deprecated( - "Use the new DOM navigation API instead (see doc/basics.md)")]] dom:: - parser::Iterator { - public: - inline Iterator(const dom::parser &parser) noexcept(false); - inline Iterator(const Iterator &o) noexcept; - inline ~Iterator() noexcept; - - inline Iterator &operator=(const Iterator &) = delete; - - inline bool is_ok() const; - - // useful for debugging purposes - inline size_t get_tape_location() const; - - // useful for debugging purposes - inline size_t get_tape_length() const; - - // returns the current depth (start at 1 with 0 reserved for the fictitious - // root node) - inline size_t get_depth() const; - - // A scope is a series of nodes at the same depth, typically it is either an - // object ({) or an array ([). The root node has type 'r'. - inline uint8_t get_scope_type() const; - - // move forward in document order - inline bool move_forward(); - - // retrieve the character code of what we're looking at: - // [{"slutfn are the possibilities - inline uint8_t get_type() const { - return current_type; // short functions should be inlined! 
- } - - // get the int64_t value at this node; valid only if get_type is "l" - inline int64_t get_integer() const { - if (location + 1 >= tape_length) { - return 0; // default value in case of error - } - return static_cast(doc.tape[location + 1]); - } - - // get the value as uint64; valid only if if get_type is "u" - inline uint64_t get_unsigned_integer() const { - if (location + 1 >= tape_length) { - return 0; // default value in case of error - } - return doc.tape[location + 1]; - } - - // get the string value at this node (NULL ended); valid only if get_type is - // " - // note that tabs, and line endings are escaped in the returned value (see - // print_with_escapes) return value is valid UTF-8, it may contain NULL - // chars - // within the string: get_string_length determines the true string length. - inline const char *get_string() const { - return reinterpret_cast( - doc.string_buf.get() + (current_val & internal::JSON_VALUE_MASK) + - sizeof(uint32_t)); - } - - // return the length of the string in bytes - inline uint32_t get_string_length() const { - uint32_t answer; - std::memcpy(&answer, - reinterpret_cast( - doc.string_buf.get() + - (current_val & internal::JSON_VALUE_MASK)), - sizeof(uint32_t)); - return answer; - } - - // get the double value at this node; valid only if - // get_type() is "d" - inline double get_double() const { - if (location + 1 >= tape_length) { - return std::numeric_limits::quiet_NaN(); // default value - // in - // case of error - } - double answer; - std::memcpy(&answer, &doc.tape[location + 1], sizeof(answer)); - return answer; - } - - inline bool is_object_or_array() const { return is_object() || is_array(); } - - inline bool is_object() const { return get_type() == '{'; } - - inline bool is_array() const { return get_type() == '['; } - - inline bool is_string() const { return get_type() == '"'; } - - // Returns true if the current type of the node is an signed integer. - // You can get its value with `get_integer()`. 
- inline bool is_integer() const { return get_type() == 'l'; } - - // Returns true if the current type of the node is an unsigned integer. - // You can get its value with `get_unsigned_integer()`. - // - // NOTE: - // Only a large value, which is out of range of a 64-bit signed integer, is - // represented internally as an unsigned node. On the other hand, a typical - // positive integer, such as 1, 42, or 1000000, is as a signed node. - // Be aware this function returns false for a signed node. - inline bool is_unsigned_integer() const { return get_type() == 'u'; } - // Returns true if the current type of the node is a double floating-point - // number. - inline bool is_double() const { return get_type() == 'd'; } - // Returns true if the current type of the node is a number (integer or - // floating-point). - inline bool is_number() const { - return is_integer() || is_unsigned_integer() || is_double(); - } - // Returns true if the current type of the node is a bool with true value. - inline bool is_true() const { return get_type() == 't'; } - // Returns true if the current type of the node is a bool with false value. - inline bool is_false() const { return get_type() == 'f'; } - // Returns true if the current type of the node is null. - inline bool is_null() const { return get_type() == 'n'; } - // Returns true if the type byte represents an object of an array - static bool is_object_or_array(uint8_t type) { - return ((type == '[') || (type == '{')); - } - - // when at {, go one level deep, looking for a given key - // if successful, we are left pointing at the value, - // if not, we are still pointing at the object ({) - // (in case of repeated keys, this only finds the first one). - // We seek the key using C's strcmp so if your JSON strings contain - // NULL chars, this would trigger a false positive: if you expect that - // to be the case, take extra precautions. 
- // Furthermore, we do the comparison character-by-character - // without taking into account Unicode equivalence. - inline bool move_to_key(const char *key); - - // as above, but case insensitive lookup (strcmpi instead of strcmp) - inline bool move_to_key_insensitive(const char *key); - - // when at {, go one level deep, looking for a given key - // if successful, we are left pointing at the value, - // if not, we are still pointing at the object ({) - // (in case of repeated keys, this only finds the first one). - // The string we search for can contain NULL values. - // Furthermore, we do the comparison character-by-character - // without taking into account Unicode equivalence. - inline bool move_to_key(const char *key, uint32_t length); - - // when at a key location within an object, this moves to the accompanying - // value (located next to it). This is equivalent but much faster than - // calling "next()". - inline void move_to_value(); - - // when at [, go one level deep, and advance to the given index. - // if successful, we are left pointing at the value, - // if not, we are still pointing at the array ([) - inline bool move_to_index(uint32_t index); - - // Moves the iterator to the value corresponding to the json pointer. - // Always search from the root of the document. - // if successful, we are left pointing at the value, - // if not, we are still pointing the same value we were pointing before the - // call. The json pointer follows the rfc6901 standard's syntax: - // https://tools.ietf.org/html/rfc6901 However, the standard says "If a - // referenced member name is not unique in an object, the member that is - // referenced is undefined, and evaluation fails". Here we just return the - // first corresponding value. The length parameter is the length of the - // jsonpointer string ('pointer'). - inline bool move_to(const char *pointer, uint32_t length); - - // Moves the iterator to the value corresponding to the json pointer. 
- // Always search from the root of the document. - // if successful, we are left pointing at the value, - // if not, we are still pointing the same value we were pointing before the - // call. The json pointer implementation follows the rfc6901 standard's - // syntax: https://tools.ietf.org/html/rfc6901 However, the standard says - // "If a referenced member name is not unique in an object, the member that - // is referenced is undefined, and evaluation fails". Here we just return - // the first corresponding value. - inline bool move_to(const std::string &pointer) { - return move_to(pointer.c_str(), uint32_t(pointer.length())); - } - - private: - // Almost the same as move_to(), except it searches from the current - // position. The pointer's syntax is identical, though that case is not - // handled by the rfc6901 standard. The '/' is still required at the - // beginning. However, contrary to move_to(), the URI Fragment Identifier - // Representation is not supported here. Also, in case of failure, we are - // left pointing at the closest value it could reach. For these reasons it - // is private. It exists because it is used by move_to(). - inline bool relative_move_to(const char *pointer, uint32_t length); - - public: - // throughout return true if we can do the navigation, false - // otherwise - - // Within a given scope (series of nodes at the same depth within either an - // array or an object), we move forward. - // Thus, given [true, null, {"a":1}, [1,2]], we would visit true, null, { - // and [. At the object ({) or at the array ([), you can issue a "down" to - // visit their content. valid if we're not at the end of a scope (returns - // true). - inline bool next(); - - // Within a given scope (series of nodes at the same depth within either an - // array or an object), we move backward. - // Thus, given [true, null, {"a":1}, [1,2]], we would visit ], }, null, true - // when starting at the end of the scope. 
At the object ({) or at the array - // ([), you can issue a "down" to visit their content. - // Performance warning: This function is implemented by starting again - // from the beginning of the scope and scanning forward. You should expect - // it to be relatively slow. - inline bool prev(); - - // Moves back to either the containing array or object (type { or [) from - // within a contained scope. - // Valid unless we are at the first level of the document - inline bool up(); - - // Valid if we're at a [ or { and it starts a non-empty scope; moves us to - // start of that deeper scope if it not empty. Thus, given [true, null, - // {"a":1}, [1,2]], if we are at the { node, we would move to the "a" node. - inline bool down(); - - // move us to the start of our current scope, - // a scope is a series of nodes at the same level - inline void to_start_scope(); - - inline void rewind() { - while (up()) - ; - } - - - // print the node we are currently pointing at - inline bool print(std::ostream & os, bool escape_strings = true) const; - - private: - const document &doc; - size_t max_depth{}; - size_t depth{}; - size_t location{}; // our current location on a tape - size_t tape_length{}; - uint8_t current_type{}; - uint64_t current_val{}; - typedef struct { - size_t start_of_scope; - uint8_t scope_type; - } scopeindex_t; - - scopeindex_t *depth_index{}; -}; - -} // namespace simdjson -#endif // SIMDJSON_DISABLE_DEPRECATED_API - -#endif // SIMDJSON_DOM_PARSEDJSON_ITERATOR_H -/* end file include/simdjson/dom/parsedjson_iterator.h */ - -// Inline functions -/* begin file include/simdjson/dom/array-inl.h */ -#ifndef SIMDJSON_INLINE_ARRAY_H -#define SIMDJSON_INLINE_ARRAY_H - -// Inline implementations go in here. 
- -#include - -namespace simdjson { - -// -// simdjson_result inline implementation -// -simdjson_really_inline simdjson_result::simdjson_result() noexcept - : internal::simdjson_result_base() {} -simdjson_really_inline simdjson_result::simdjson_result( - dom::array value) noexcept - : internal::simdjson_result_base( - std::forward(value)) {} -simdjson_really_inline simdjson_result::simdjson_result( - error_code error) noexcept - : internal::simdjson_result_base(error) {} - -#if SIMDJSON_EXCEPTIONS - -inline dom::array::iterator simdjson_result::begin() const - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.begin(); -} -inline dom::array::iterator simdjson_result::end() const - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.end(); -} -inline size_t simdjson_result::size() const noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.size(); -} - -#endif // SIMDJSON_EXCEPTIONS - -inline simdjson_result simdjson_result::at_pointer( - std::string_view json_pointer) const noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} -inline simdjson_result simdjson_result::at( - size_t index) const noexcept { - if (error()) { - return error(); - } - return first.at(index); -} - -namespace dom { - -// -// array inline implementation -// -simdjson_really_inline array::array() noexcept : tape{} {} -simdjson_really_inline array::array(const internal::tape_ref &_tape) noexcept - : tape{_tape} {} -inline array::iterator array::begin() const noexcept { - return internal::tape_ref(tape.doc, tape.json_index + 1); -} -inline array::iterator array::end() const noexcept { - return internal::tape_ref(tape.doc, tape.after_element() - 1); -} -inline size_t array::size() const noexcept { return tape.scope_count(); } -inline size_t array::number_of_slots() const noexcept { - return tape.matching_brace_index() - tape.json_index; -} -inline 
simdjson_result array::at_pointer( - std::string_view json_pointer) const noexcept { - if (json_pointer.empty()) { // an empty string means that we return the - // current node - return element(this->tape); // copy the current node - } else if (json_pointer[0] != '/') { // otherwise there is an error - return INVALID_JSON_POINTER; - } - json_pointer = json_pointer.substr(1); - // - means "the append position" or "the element after the end of the array" - // We don't support this, because we're returning a real element, not a - // position. - if (json_pointer == "-") { - return INDEX_OUT_OF_BOUNDS; - } - - // Read the array index - size_t array_index = 0; - size_t i; - for (i = 0; i < json_pointer.length() && json_pointer[i] != '/'; i++) { - uint8_t digit = uint8_t(json_pointer[i] - '0'); - // Check for non-digit in array index. If it's there, we're trying to - // get a field in an object - if (digit > 9) { - return INCORRECT_TYPE; - } - array_index = array_index * 10 + digit; - } - - // 0 followed by other digits is invalid - if (i > 1 && json_pointer[0] == '0') { - return INVALID_JSON_POINTER; - } // "JSON pointer array index has other characters after 0" - - // Empty string is invalid; so is a "/" with no digits before it - if (i == 0) { - return INVALID_JSON_POINTER; - } // "Empty string in JSON pointer array index" - - // Get the child - auto child = array(tape).at(array_index); - // If there is an error, it ends here - if (child.error()) { - return child; - } - // If there is a /, we're not done yet, call recursively. 
- if (i < json_pointer.length()) { - child = child.at_pointer(json_pointer.substr(i)); - } - return child; -} - -inline simdjson_result array::at(size_t index) const noexcept { - size_t i = 0; - for (auto element : *this) { - if (i == index) { - return element; - } - i++; - } - return INDEX_OUT_OF_BOUNDS; -} - -// -// array::iterator inline implementation -// -simdjson_really_inline array::iterator::iterator( - const internal::tape_ref &_tape) noexcept : tape{_tape} {} -inline element array::iterator::operator*() const noexcept { - return element(tape); -} -inline array::iterator &array::iterator::operator++() noexcept { - tape.json_index = tape.after_element(); - return *this; -} -inline array::iterator array::iterator::operator++(int)noexcept { - array::iterator out = *this; - ++*this; - return out; -} -inline bool array::iterator::operator!=(const array::iterator &other) const - noexcept { - return tape.json_index != other.tape.json_index; -} -inline bool array::iterator::operator==(const array::iterator &other) const - noexcept { - return tape.json_index == other.tape.json_index; -} -inline bool array::iterator::operator<(const array::iterator &other) const - noexcept { - return tape.json_index < other.tape.json_index; -} -inline bool array::iterator::operator<=(const array::iterator &other) const - noexcept { - return tape.json_index <= other.tape.json_index; -} -inline bool array::iterator::operator>=(const array::iterator &other) const - noexcept { - return tape.json_index >= other.tape.json_index; -} -inline bool array::iterator::operator>(const array::iterator &other) const - noexcept { - return tape.json_index > other.tape.json_index; -} - -} // namespace dom - - -} // namespace simdjson - -/* begin file include/simdjson/dom/element-inl.h */ -#ifndef SIMDJSON_INLINE_ELEMENT_H -#define SIMDJSON_INLINE_ELEMENT_H - -#include -#include - -namespace simdjson { - -// -// simdjson_result inline implementation -// -simdjson_really_inline 
simdjson_result::simdjson_result() noexcept - : internal::simdjson_result_base() {} -simdjson_really_inline simdjson_result::simdjson_result( - dom::element &&value) noexcept - : internal::simdjson_result_base( - std::forward(value)) {} -simdjson_really_inline simdjson_result::simdjson_result( - error_code error) noexcept - : internal::simdjson_result_base(error) {} -inline simdjson_result simdjson_result::type() - const noexcept { - if (error()) { - return error(); - } - return first.type(); -} - -template -simdjson_really_inline bool simdjson_result::is() const noexcept { - return !error() && first.is(); -} -template -simdjson_really_inline simdjson_result simdjson_result::get() - const noexcept { - if (error()) { - return error(); - } - return first.get(); -} -template -simdjson_warn_unused simdjson_really_inline error_code -simdjson_result::get(T &value) const noexcept { - if (error()) { - return error(); - } - return first.get(value); -} - -simdjson_really_inline simdjson_result -simdjson_result::get_array() const noexcept { - if (error()) { - return error(); - } - return first.get_array(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_object() const noexcept { - if (error()) { - return error(); - } - return first.get_object(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_c_str() const noexcept { - if (error()) { - return error(); - } - return first.get_c_str(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_string_length() const noexcept { - if (error()) { - return error(); - } - return first.get_string_length(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_string() const noexcept { - if (error()) { - return error(); - } - return first.get_string(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_int64() const noexcept { - if (error()) { - return error(); - } - return first.get_int64(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_uint64() const 
noexcept { - if (error()) { - return error(); - } - return first.get_uint64(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_double() const noexcept { - if (error()) { - return error(); - } - return first.get_double(); -} -simdjson_really_inline simdjson_result -simdjson_result::get_bool() const noexcept { - if (error()) { - return error(); - } - return first.get_bool(); -} - -simdjson_really_inline bool simdjson_result::is_array() const - noexcept { - return !error() && first.is_array(); -} -simdjson_really_inline bool simdjson_result::is_object() const - noexcept { - return !error() && first.is_object(); -} -simdjson_really_inline bool simdjson_result::is_string() const - noexcept { - return !error() && first.is_string(); -} -simdjson_really_inline bool simdjson_result::is_int64() const - noexcept { - return !error() && first.is_int64(); -} -simdjson_really_inline bool simdjson_result::is_uint64() const - noexcept { - return !error() && first.is_uint64(); -} -simdjson_really_inline bool simdjson_result::is_double() const - noexcept { - return !error() && first.is_double(); -} -simdjson_really_inline bool simdjson_result::is_number() const - noexcept { - return !error() && first.is_number(); -} -simdjson_really_inline bool simdjson_result::is_bool() const - noexcept { - return !error() && first.is_bool(); -} - -simdjson_really_inline bool simdjson_result::is_null() const - noexcept { - return !error() && first.is_null(); -} - -simdjson_really_inline simdjson_result - simdjson_result::operator[](std::string_view key) const - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline simdjson_result - simdjson_result::operator[](const char *key) const noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline simdjson_result simdjson_result< - dom::element>::at_pointer(const std::string_view json_pointer) const - noexcept { - if (error()) { - return error(); - } - return 
first.at_pointer(json_pointer); -} -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -[[deprecated( - "For standard compliance, use at_pointer instead, and prefix your pointers " - "with a slash '/', see RFC6901 ")]] simdjson_really_inline - simdjson_result - simdjson_result::at(const std::string_view json_pointer) const - noexcept { - SIMDJSON_PUSH_DISABLE_WARNINGS - SIMDJSON_DISABLE_DEPRECATED_WARNING - if (error()) { - return error(); - } - return first.at(json_pointer); - SIMDJSON_POP_DISABLE_WARNINGS -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API -simdjson_really_inline simdjson_result -simdjson_result::at(size_t index) const noexcept { - if (error()) { - return error(); - } - return first.at(index); -} -simdjson_really_inline simdjson_result -simdjson_result::at_key(std::string_view key) const noexcept { - if (error()) { - return error(); - } - return first.at_key(key); -} -simdjson_really_inline simdjson_result simdjson_result< - dom::element>::at_key_case_insensitive(std::string_view key) const - noexcept { - if (error()) { - return error(); - } - return first.at_key_case_insensitive(key); -} - -#if SIMDJSON_EXCEPTIONS - -simdjson_really_inline simdjson_result::operator bool() const - noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator const char *() - const noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result:: -operator std::string_view() const noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator uint64_t() const - noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator int64_t() const - noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator double() const - noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator dom::array() - const noexcept(false) { - return get(); -} -simdjson_really_inline simdjson_result::operator dom::object() - const noexcept(false) { - return get(); -} - 
-simdjson_really_inline dom::array::iterator -simdjson_result::begin() const noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.begin(); -} -simdjson_really_inline dom::array::iterator simdjson_result::end() - const noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.end(); -} - -#endif // SIMDJSON_EXCEPTIONS - -namespace dom { - -// -// element inline implementation -// -simdjson_really_inline element::element() noexcept : tape{} {} -simdjson_really_inline element::element( - const internal::tape_ref &_tape) noexcept : tape{_tape} {} - -inline element_type element::type() const noexcept { - auto tape_type = tape.tape_ref_type(); - return tape_type == internal::tape_type::FALSE_VALUE - ? element_type::BOOL - : static_cast(tape_type); -} - -inline simdjson_result element::get_bool() const noexcept { - if (tape.is_true()) { - return true; - } else if (tape.is_false()) { - return false; - } - return INCORRECT_TYPE; -} -inline simdjson_result element::get_c_str() const noexcept { - switch (tape.tape_ref_type()) { - case internal::tape_type::STRING: { - return tape.get_c_str(); - } - default: - return INCORRECT_TYPE; - } -} -inline simdjson_result element::get_string_length() const noexcept { - switch (tape.tape_ref_type()) { - case internal::tape_type::STRING: { - return tape.get_string_length(); - } - default: - return INCORRECT_TYPE; - } -} -inline simdjson_result element::get_string() const noexcept { - switch (tape.tape_ref_type()) { - case internal::tape_type::STRING: - return tape.get_string_view(); - default: - return INCORRECT_TYPE; - } -} -inline simdjson_result element::get_uint64() const noexcept { - if (simdjson_unlikely(!tape.is_uint64())) { // branch rarely taken - if (tape.is_int64()) { - int64_t result = tape.next_tape_value(); - if (result < 0) { - return NUMBER_OUT_OF_RANGE; - } - return uint64_t(result); - } - return INCORRECT_TYPE; - } - return tape.next_tape_value(); -} -inline 
simdjson_result element::get_int64() const noexcept { - if (simdjson_unlikely(!tape.is_int64())) { // branch rarely taken - if (tape.is_uint64()) { - uint64_t result = tape.next_tape_value(); - // Wrapping max in parens to handle Windows issue: - // https://stackoverflow.com/questions/11544073/how-do-i-deal-with-the-max-macro-in-windows-h-colliding-with-max-in-std - if (result > uint64_t((std::numeric_limits::max)())) { - return NUMBER_OUT_OF_RANGE; - } - return static_cast(result); - } - return INCORRECT_TYPE; - } - return tape.next_tape_value(); -} -inline simdjson_result element::get_double() const noexcept { - // Performance considerations: - // 1. Querying tape_ref_type() implies doing a shift, it is fast to just do - // a straight - // comparison. - // 2. Using a switch-case relies on the compiler guessing what kind of code - // generation - // we want... But the compiler cannot know that we expect the type to be - // "double" - // most of the time. - // We can expect get to refer to a double type almost all the time. - // It is important to craft the code accordingly so that the compiler can - // use this - // information. (This could also be solved with profile-guided - // optimization.) 
- if (simdjson_unlikely(!tape.is_double())) { // branch rarely taken - if (tape.is_uint64()) { - return double(tape.next_tape_value()); - } else if (tape.is_int64()) { - return double(tape.next_tape_value()); - } - return INCORRECT_TYPE; - } - // this is common: - return tape.next_tape_value(); -} -inline simdjson_result element::get_array() const noexcept { - switch (tape.tape_ref_type()) { - case internal::tape_type::START_ARRAY: - return array(tape); - default: - return INCORRECT_TYPE; - } -} -inline simdjson_result element::get_object() const noexcept { - switch (tape.tape_ref_type()) { - case internal::tape_type::START_OBJECT: - return object(tape); - default: - return INCORRECT_TYPE; - } -} - -template -simdjson_warn_unused simdjson_really_inline error_code -element::get(T &value) const noexcept { - return get().get(value); -} -// An element-specific version prevents recursion with -// simdjson_result::get(value) -template <> -simdjson_warn_unused simdjson_really_inline error_code -element::get(element &value) const noexcept { - value = element(tape); - return SUCCESS; -} -template - inline void element::tie(T &value, error_code &error) && noexcept { - error = get(value); -} - -template -simdjson_really_inline bool element::is() const noexcept { - auto result = get(); - return !result.error(); -} - -template <> -inline simdjson_result element::get() const noexcept { - return get_array(); -} -template <> -inline simdjson_result element::get() const noexcept { - return get_object(); -} -template <> -inline simdjson_result element::get() const - noexcept { - return get_c_str(); -} -template <> -inline simdjson_result element::get() const - noexcept { - return get_string(); -} -template <> -inline simdjson_result element::get() const noexcept { - return get_int64(); -} -template <> -inline simdjson_result element::get() const noexcept { - return get_uint64(); -} -template <> -inline simdjson_result element::get() const noexcept { - return get_double(); -} 
-template <> -inline simdjson_result element::get() const noexcept { - return get_bool(); -} - -inline bool element::is_array() const noexcept { return is(); } -inline bool element::is_object() const noexcept { return is(); } -inline bool element::is_string() const noexcept { - return is(); -} -inline bool element::is_int64() const noexcept { return is(); } -inline bool element::is_uint64() const noexcept { return is(); } -inline bool element::is_double() const noexcept { return is(); } -inline bool element::is_bool() const noexcept { return is(); } -inline bool element::is_number() const noexcept { - return is_int64() || is_uint64() || is_double(); -} - -inline bool element::is_null() const noexcept { return tape.is_null_on_tape(); } - -#if SIMDJSON_EXCEPTIONS - -inline element::operator bool() const noexcept(false) { return get(); } -inline element::operator const char *() const noexcept(false) { - return get(); -} -inline element::operator std::string_view() const noexcept(false) { - return get(); -} -inline element::operator uint64_t() const noexcept(false) { - return get(); -} -inline element::operator int64_t() const noexcept(false) { - return get(); -} -inline element::operator double() const noexcept(false) { - return get(); -} -inline element::operator array() const noexcept(false) { return get(); } -inline element::operator object() const noexcept(false) { - return get(); -} - -inline array::iterator element::begin() const noexcept(false) { - return get().begin(); -} -inline array::iterator element::end() const noexcept(false) { - return get().end(); -} - -#endif // SIMDJSON_EXCEPTIONS - -inline simdjson_result element::operator[](std::string_view key) const - noexcept { - return at_key(key); -} -inline simdjson_result element::operator[](const char *key) const - noexcept { - return at_key(key); -} - -inline simdjson_result element::at_pointer( - std::string_view json_pointer) const noexcept { - switch (tape.tape_ref_type()) { - case 
internal::tape_type::START_OBJECT: - return object(tape).at_pointer(json_pointer); - case internal::tape_type::START_ARRAY: - return array(tape).at_pointer(json_pointer); - default: { - if (!json_pointer - .empty()) { // a non-empty string is invalid on an atom - return INVALID_JSON_POINTER; - } - // an empty string means that we return the current node - dom::element copy(*this); - return simdjson_result(std::move(copy)); - } - } -} -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -[[deprecated( - "For standard compliance, use at_pointer instead, and prefix your pointers " - "with a slash '/', see RFC6901 ")]] inline simdjson_result -element::at(std::string_view json_pointer) const noexcept { - // version 0.4 of simdjson allowed non-compliant pointers - auto std_pointer = (json_pointer.empty() ? "" : "/") + - std::string(json_pointer.begin(), json_pointer.end()); - return at_pointer(std_pointer); -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API - -inline simdjson_result element::at(size_t index) const noexcept { - return get().at(index); -} -inline simdjson_result element::at_key(std::string_view key) const - noexcept { - return get().at_key(key); -} -inline simdjson_result element::at_key_case_insensitive( - std::string_view key) const noexcept { - return get().at_key_case_insensitive(key); -} - -inline bool element::dump_raw_tape(std::ostream &out) const noexcept { - return tape.doc->dump_raw_tape(out); -} - - -inline std::ostream &operator<<(std::ostream &out, element_type type) { - switch (type) { - case element_type::ARRAY: - return out << "array"; - case element_type::OBJECT: - return out << "object"; - case element_type::INT64: - return out << "int64_t"; - case element_type::UINT64: - return out << "uint64_t"; - case element_type::DOUBLE: - return out << "double"; - case element_type::STRING: - return out << "string"; - case element_type::BOOL: - return out << "bool"; - case element_type::NULL_VALUE: - return out << "null"; - default: - return out << "unexpected 
content!!!"; // abort() usage is - // forbidden in the library - } -} - -} // namespace dom - -} // namespace simdjson - -#endif // SIMDJSON_INLINE_ELEMENT_H -/* end file include/simdjson/dom/element-inl.h */ - -#if defined(__cpp_lib_ranges) -static_assert(std::ranges::view); -static_assert(std::ranges::sized_range); -#if SIMDJSON_EXCEPTIONS -static_assert( - std::ranges::view>); -static_assert( - std::ranges::sized_range>); -#endif // SIMDJSON_EXCEPTIONS -#endif // defined(__cpp_lib_ranges) - -#endif // SIMDJSON_INLINE_ARRAY_H -/* end file include/simdjson/dom/array-inl.h */ -/* begin file include/simdjson/dom/document_stream-inl.h */ -#ifndef SIMDJSON_INLINE_DOCUMENT_STREAM_H -#define SIMDJSON_INLINE_DOCUMENT_STREAM_H - -#include -#include -#include -namespace simdjson { -namespace dom { - -#ifdef SIMDJSON_THREADS_ENABLED -inline void stage1_worker::finish() { - // After calling "run" someone would call finish() to wait - // for the end of the processing. - // This function will wait until either the thread has done - // the processing or, else, the destructor has been called. - std::unique_lock lock(locking_mutex); - cond_var.wait(lock, [this] { return has_work == false; }); -} - -inline stage1_worker::~stage1_worker() { - // The thread may never outlive the stage1_worker instance - // and will always be stopped/joined before the stage1_worker - // instance is gone. - stop_thread(); -} - -inline void stage1_worker::start_thread() { - std::unique_lock lock(locking_mutex); - if (thread.joinable()) { - return; // This should never happen but we never want to create more - // than one thread. - } - thread = std::thread([this] { - while (true) { - std::unique_lock thread_lock(locking_mutex); - // We wait for either "run" or "stop_thread" to be called. 
- cond_var.wait(thread_lock, - [this] { return has_work || !can_work; }); - // If, for some reason, the stop_thread() method was called (i.e., - // the - // destructor of stage1_worker is called, then we want to - // immediately destroy - // the thread (and not do any more processing). - if (!can_work) { - break; - } - this->owner->stage1_thread_error = this->owner->run_stage1( - *this->stage1_thread_parser, this->_next_batch_start); - this->has_work = false; - // The condition variable call should be moved after - // thread_lock.unlock() for performance - // reasons but thread sanitizers may report it as a data race if we - // do. - // See - // https://stackoverflow.com/questions/35775501/c-should-condition-variable-be-notified-under-lock - cond_var.notify_one(); // will notify "finish" - thread_lock.unlock(); - } - }); -} - - -inline void stage1_worker::stop_thread() { - std::unique_lock lock(locking_mutex); - // We have to make sure that all locks can be released. - can_work = false; - has_work = false; - cond_var.notify_all(); - lock.unlock(); - if (thread.joinable()) { - thread.join(); - } -} - -inline void stage1_worker::run(document_stream *ds, - dom::parser *stage1, - size_t next_batch_start) { - std::unique_lock lock(locking_mutex); - owner = ds; - _next_batch_start = next_batch_start; - stage1_thread_parser = stage1; - has_work = true; - // The condition variable call should be moved after thread_lock.unlock() - // for performance - // reasons but thread sanitizers may report it as a data race if we do. 
- // See - // https://stackoverflow.com/questions/35775501/c-should-condition-variable-be-notified-under-lock - cond_var.notify_one(); // will notify the thread lock that we have work - lock.unlock(); -} -#endif - -simdjson_really_inline document_stream::document_stream( - dom::parser &_parser, - const uint8_t *_buf, - size_t _len, - size_t _batch_size) noexcept - : parser{&_parser}, - buf{_buf}, - len{_len}, - batch_size{_batch_size <= MINIMAL_BATCH_SIZE ? MINIMAL_BATCH_SIZE - : _batch_size}, - error { - SUCCESS -} -#ifdef SIMDJSON_THREADS_ENABLED -, use_thread(_parser.threaded) // we need to make a copy because - // _parser.threaded can change -#endif -{ -#ifdef SIMDJSON_THREADS_ENABLED - if (worker.get() == nullptr) { - error = MEMALLOC; - } -#endif -} - -simdjson_really_inline document_stream::document_stream() noexcept - : parser{nullptr}, - buf{nullptr}, - len{0}, - batch_size{0}, - error { - UNINITIALIZED -} -#ifdef SIMDJSON_THREADS_ENABLED -, use_thread(false) -#endif -{ -} - -simdjson_really_inline document_stream::~document_stream() noexcept { -#ifdef SIMDJSON_THREADS_ENABLED - worker.reset(); -#endif -} - -simdjson_really_inline document_stream::iterator::iterator() noexcept - : stream{nullptr}, - finished{true} {} - -simdjson_really_inline document_stream::iterator -document_stream::begin() noexcept { - start(); - // If there are no documents, we're finished. 
- return iterator(this, error == EMPTY); -} - -simdjson_really_inline document_stream::iterator -document_stream::end() noexcept { - return iterator(this, true); -} - -simdjson_really_inline document_stream::iterator::iterator( - document_stream *_stream, bool is_end) noexcept : stream{_stream}, - finished{is_end} {} - -simdjson_really_inline document_stream::iterator::reference - document_stream::iterator::operator*() noexcept { - // Note that in case of error, we do not yet mark - // the iterator as "finished": this detection is done - // in the operator++ function since it is possible - // to call operator++ repeatedly while omitting - // calls to operator*. - if (stream->error) { - return stream->error; - } - return stream->parser->doc.root(); -} - -simdjson_really_inline document_stream::iterator - &document_stream::iterator::operator++() noexcept { - // If there is an error, then we want the iterator - // to be finished, no matter what. (E.g., we do not - // keep generating documents with errors, or go beyond - // a document with errors.) - // - // Users do not have to call "operator*()" when they use operator++, - // so we need to end the stream in the operator++ function. - // - // Note that setting finished = true is essential otherwise - // we would enter an infinite loop. - if (stream->error) { - finished = true; - } - // Note that stream->error() is guarded against error conditions - // (it will immediately return if stream->error casts to false). - // In effect, this next function does nothing when (stream->error) - // is true (hence the risk of an infinite loop). - stream->next(); - // If that was the last document, we're finished. - // It is the only type of error we do not want to appear - // in operator*. - if (stream->error == EMPTY) { - finished = true; - } - // If we had any other kind of error (not EMPTY) then we want - // to pass it along to the operator* and we cannot mark the result - // as "finished" just yet. 
- return *this; -} - -simdjson_really_inline bool document_stream::iterator::operator!=( - const document_stream::iterator &other) const noexcept { - return finished != other.finished; -} - -inline void document_stream::start() noexcept { - if (error) { - return; - } - error = parser->ensure_capacity(batch_size); - if (error) { - return; - } - // Always run the first stage 1 parse immediately - batch_start = 0; - error = run_stage1(*parser, batch_start); - while (error == EMPTY) { - // In exceptional cases, we may start with an empty block - batch_start = next_batch_start(); - if (batch_start >= len) { - return; - } - error = run_stage1(*parser, batch_start); - } - if (error) { - return; - } -#ifdef SIMDJSON_THREADS_ENABLED - if (use_thread && next_batch_start() < len) { - // Kick off the first thread if needed - error = stage1_thread_parser.ensure_capacity(batch_size); - if (error) { - return; - } - worker->start_thread(); - start_stage1_thread(); - if (error) { - return; - } - } -#endif // SIMDJSON_THREADS_ENABLED - next(); -} - -simdjson_really_inline size_t document_stream::iterator::current_index() const - noexcept { - return stream->doc_index; -} - -simdjson_really_inline std::string_view document_stream::iterator::source() - const noexcept { - const char *start = - reinterpret_cast(stream->buf) + current_index(); - bool object_or_array = ((*start == '[') || (*start == '{')); - if (object_or_array) { - size_t next_doc_index = - stream->batch_start + - stream->parser->implementation->structural_indexes - [stream->parser->implementation->next_structural_index - 1]; - return std::string_view(start, next_doc_index - current_index() + 1); - } else { - size_t next_doc_index = - stream->batch_start + - stream->parser->implementation->structural_indexes - [stream->parser->implementation->next_structural_index]; - return std::string_view( - reinterpret_cast(stream->buf) + current_index(), - next_doc_index - current_index() - 1); - } -} - - -inline void 
document_stream::next() noexcept { - // We always exit at once, once in an error condition. - if (error) { - return; - } - - // Load the next document from the batch - doc_index = - batch_start + - parser->implementation - ->structural_indexes[parser->implementation->next_structural_index]; - error = parser->implementation->stage2_next(parser->doc); - // If that was the last document in the batch, load another batch (if - // available) - while (error == EMPTY) { - batch_start = next_batch_start(); - if (batch_start >= len) { - break; - } - -#ifdef SIMDJSON_THREADS_ENABLED - if (use_thread) { - load_from_stage1_thread(); - } else { - error = run_stage1(*parser, batch_start); - } -#else - error = run_stage1(*parser, batch_start); -#endif - if (error) { - continue; - } // If the error was EMPTY, we may want to load another batch. - // Run stage 2 on the first document in the batch - doc_index = batch_start + - parser->implementation->structural_indexes - [parser->implementation->next_structural_index]; - error = parser->implementation->stage2_next(parser->doc); - } -} -inline size_t document_stream::size_in_bytes() const noexcept { return len; } - -inline size_t document_stream::truncated_bytes() const noexcept { - if (error == CAPACITY) { - return len - batch_start; - } - return parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes] - - parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes + 1]; -} - -inline size_t document_stream::next_batch_start() const noexcept { - return batch_start + - parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes]; -} - -inline error_code document_stream::run_stage1(dom::parser &p, - size_t _batch_start) noexcept { - size_t remaining = len - _batch_start; - if (remaining <= batch_size) { - return p.implementation->stage1( - &buf[_batch_start], remaining, stage1_mode::streaming_final); - } else { - return 
p.implementation->stage1( - &buf[_batch_start], batch_size, stage1_mode::streaming_partial); - } -} - -#ifdef SIMDJSON_THREADS_ENABLED - -inline void document_stream::load_from_stage1_thread() noexcept { - worker->finish(); - // Swap to the parser that was loaded up in the thread. Make sure the parser - // has - // enough memory to swap to, as well. - std::swap(*parser, stage1_thread_parser); - error = stage1_thread_error; - if (error) { - return; - } - - // If there's anything left, start the stage 1 thread! - if (next_batch_start() < len) { - start_stage1_thread(); - } -} - -inline void document_stream::start_stage1_thread() noexcept { - // we call the thread on a lambda that will update - // this->stage1_thread_error - // there is only one thread that may write to this value - // TODO this is NOT exception-safe. - this->stage1_thread_error = - UNINITIALIZED; // In case something goes wrong, make sure it's an error - size_t _next_batch_start = this->next_batch_start(); - - worker->run(this, &this->stage1_thread_parser, _next_batch_start); -} - -#endif // SIMDJSON_THREADS_ENABLED - -} // namespace dom - -simdjson_really_inline -simdjson_result::simdjson_result() noexcept - : simdjson_result_base() {} -simdjson_really_inline simdjson_result::simdjson_result( - error_code error) noexcept : simdjson_result_base(error) {} -simdjson_really_inline simdjson_result::simdjson_result( - dom::document_stream &&value) noexcept - : simdjson_result_base(std::forward(value)) {} - -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline dom::document_stream::iterator -simdjson_result::begin() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.begin(); -} -simdjson_really_inline dom::document_stream::iterator -simdjson_result::end() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.end(); -} -#else // SIMDJSON_EXCEPTIONS -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -simdjson_really_inline dom::document_stream::iterator 
-simdjson_result::begin() noexcept { - first.error = error(); - return first.begin(); -} -simdjson_really_inline dom::document_stream::iterator -simdjson_result::end() noexcept { - first.error = error(); - return first.end(); -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API -#endif // SIMDJSON_EXCEPTIONS - -} // namespace simdjson -#endif // SIMDJSON_INLINE_DOCUMENT_STREAM_H -/* end file include/simdjson/dom/document_stream-inl.h */ -/* begin file include/simdjson/dom/document-inl.h */ -#ifndef SIMDJSON_INLINE_DOCUMENT_H -#define SIMDJSON_INLINE_DOCUMENT_H - -// Inline implementations go in here. - -#include -#include - -namespace simdjson { -namespace dom { - -// -// document inline implementation -// -inline element document::root() const noexcept { - return element(internal::tape_ref(this, 1)); -} -simdjson_warn_unused inline size_t document::capacity() const noexcept { - return allocated_capacity; -} - -simdjson_warn_unused inline error_code document::allocate( - size_t capacity) noexcept { - if (capacity == 0) { - string_buf.reset(); - tape.reset(); - allocated_capacity = 0; - return SUCCESS; - } - - // a pathological input like "[[[[..." would generate capacity tape - // elements, so - // need a capacity of at least capacity + 1, but it is also possible to do - // worse with - // "[7,7,7,7,6,7,7,7,6,7,7,6,[7,7,7,7,6,7,7,7,6,7,7,6,7,7,7,7,7,7,6" - // where capacity + 1 tape elements are - // generated, see issue https://github.com/simdjson/simdjson/issues/345 - size_t tape_capacity = SIMDJSON_ROUNDUP_N(capacity + 3, 64); - // a document with only zero-length strings... 
could have capacity/3 string - // and we would need capacity/3 * 5 bytes on the string buffer - size_t string_capacity = - SIMDJSON_ROUNDUP_N(5 * capacity / 3 + SIMDJSON_PADDING, 64); - string_buf.reset(new (std::nothrow) uint8_t[string_capacity]); - tape.reset(new (std::nothrow) uint64_t[tape_capacity]); - if (!(string_buf && tape)) { - allocated_capacity = 0; - string_buf.reset(); - tape.reset(); - return MEMALLOC; - } - // Technically the allocated_capacity might be larger than capacity - // so the next line is pessimistic. - allocated_capacity = capacity; - return SUCCESS; -} - -inline bool document::dump_raw_tape(std::ostream &os) const noexcept { - uint32_t string_length; - size_t tape_idx = 0; - uint64_t tape_val = tape[tape_idx]; - uint8_t type = uint8_t(tape_val >> 56); - os << tape_idx << " : " << type; - tape_idx++; - size_t how_many = 0; - if (type == 'r') { - how_many = size_t(tape_val & internal::JSON_VALUE_MASK); - } else { - // Error: no starting root node? - return false; - } - os << "\t// pointing to " << how_many << " (right after last node)\n"; - uint64_t payload; - for (; tape_idx < how_many; tape_idx++) { - os << tape_idx << " : "; - tape_val = tape[tape_idx]; - payload = tape_val & internal::JSON_VALUE_MASK; - type = uint8_t(tape_val >> 56); - switch (type) { - case '"': // we have a string - os << "string \""; - std::memcpy(&string_length, - string_buf.get() + payload, - sizeof(uint32_t)); - os << internal::escape_json_string(std::string_view( - reinterpret_cast(string_buf.get() + payload + - sizeof(uint32_t)), - string_length)); - os << '"'; - os << '\n'; - break; - case 'l': // we have a long int - if (tape_idx + 1 >= how_many) { - return false; - } - os << "integer " << static_cast(tape[++tape_idx]) - << "\n"; - break; - case 'u': // we have a long uint - if (tape_idx + 1 >= how_many) { - return false; - } - os << "unsigned integer " << tape[++tape_idx] << "\n"; - break; - case 'd': // we have a double - os << "float "; - if (tape_idx + 1 
>= how_many) { - return false; - } - double answer; - std::memcpy(&answer, &tape[++tape_idx], sizeof(answer)); - os << answer << '\n'; - break; - case 'n': // we have a null - os << "null\n"; - break; - case 't': // we have a true - os << "true\n"; - break; - case 'f': // we have a false - os << "false\n"; - break; - case '{': // we have an object - os << "{\t// pointing to next tape location " - << uint32_t(payload) << " (first node after the scope), " - << " saturated count " - << ((payload >> 32) & internal::JSON_COUNT_MASK) << "\n"; - break; - case '}': // we end an object - os << "}\t// pointing to previous tape location " - << uint32_t(payload) << " (start of the scope)\n"; - break; - case '[': // we start an array - os << "[\t// pointing to next tape location " - << uint32_t(payload) << " (first node after the scope), " - << " saturated count " - << ((payload >> 32) & internal::JSON_COUNT_MASK) << "\n"; - break; - case ']': // we end an array - os << "]\t// pointing to previous tape location " - << uint32_t(payload) << " (start of the scope)\n"; - break; - case 'r': // we start and end with the root node - // should we be hitting the root node? 
- return false; - default: - return false; - } - } - tape_val = tape[tape_idx]; - payload = tape_val & internal::JSON_VALUE_MASK; - type = uint8_t(tape_val >> 56); - os << tape_idx << " : " << type << "\t// pointing to " << payload - << " (start root)\n"; - return true; -} - -} // namespace dom -} // namespace simdjson - -#endif // SIMDJSON_INLINE_DOCUMENT_H -/* end file include/simdjson/dom/document-inl.h */ -/* begin file include/simdjson/dom/object-inl.h */ -#ifndef SIMDJSON_INLINE_OBJECT_H -#define SIMDJSON_INLINE_OBJECT_H - -#include -#include - -namespace simdjson { - -// -// simdjson_result inline implementation -// -simdjson_really_inline simdjson_result::simdjson_result() noexcept - : internal::simdjson_result_base() {} -simdjson_really_inline simdjson_result::simdjson_result( - dom::object value) noexcept - : internal::simdjson_result_base( - std::forward(value)) {} -simdjson_really_inline simdjson_result::simdjson_result( - error_code error) noexcept - : internal::simdjson_result_base(error) {} - -inline simdjson_result simdjson_result::operator[]( - std::string_view key) const noexcept { - if (error()) { - return error(); - } - return first[key]; -} -inline simdjson_result simdjson_result::operator[]( - const char *key) const noexcept { - if (error()) { - return error(); - } - return first[key]; -} -inline simdjson_result simdjson_result::at_pointer( - std::string_view json_pointer) const noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} -inline simdjson_result simdjson_result::at_key( - std::string_view key) const noexcept { - if (error()) { - return error(); - } - return first.at_key(key); -} -inline simdjson_result simdjson_result< - dom::object>::at_key_case_insensitive(std::string_view key) const noexcept { - if (error()) { - return error(); - } - return first.at_key_case_insensitive(key); -} - -#if SIMDJSON_EXCEPTIONS - -inline dom::object::iterator simdjson_result::begin() const - noexcept(false) { - if 
(error()) { - throw simdjson_error(error()); - } - return first.begin(); -} -inline dom::object::iterator simdjson_result::end() const - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.end(); -} -inline size_t simdjson_result::size() const noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first.size(); -} - -#endif // SIMDJSON_EXCEPTIONS - -namespace dom { - -// -// object inline implementation -// -simdjson_really_inline object::object() noexcept : tape{} {} -simdjson_really_inline object::object(const internal::tape_ref &_tape) noexcept - : tape{_tape} {} -inline object::iterator object::begin() const noexcept { - return internal::tape_ref(tape.doc, tape.json_index + 1); -} -inline object::iterator object::end() const noexcept { - return internal::tape_ref(tape.doc, tape.after_element() - 1); -} -inline size_t object::size() const noexcept { return tape.scope_count(); } - -inline simdjson_result object::operator[](std::string_view key) const - noexcept { - return at_key(key); -} -inline simdjson_result object::operator[](const char *key) const - noexcept { - return at_key(key); -} -inline simdjson_result object::at_pointer( - std::string_view json_pointer) const noexcept { - if (json_pointer.empty()) { // an empty string means that we return the - // current node - return element(this->tape); // copy the current node - } else if (json_pointer[0] != '/') { // otherwise there is an error - return INVALID_JSON_POINTER; - } - json_pointer = json_pointer.substr(1); - size_t slash = json_pointer.find('/'); - std::string_view key = json_pointer.substr(0, slash); - // Grab the child with the given key - simdjson_result child; - - // If there is an escape character in the key, unescape it and then get the - // child. 
- size_t escape = key.find('~'); - if (escape != std::string_view::npos) { - // Unescape the key - std::string unescaped(key); - do { - switch (unescaped[escape + 1]) { - case '0': - unescaped.replace(escape, 2, "~"); - break; - case '1': - unescaped.replace(escape, 2, "/"); - break; - default: - return INVALID_JSON_POINTER; // "Unexpected ~ escape - // character in JSON - // pointer"); - } - escape = unescaped.find('~', escape + 1); - } while (escape != std::string::npos); - child = at_key(unescaped); - } else { - child = at_key(key); - } - if (child.error()) { - return child; // we do not continue if there was an error - } - // If there is a /, we have to recurse and look up more of the path - if (slash != std::string_view::npos) { - child = child.at_pointer(json_pointer.substr(slash)); - } - return child; -} - -inline simdjson_result object::at_key(std::string_view key) const - noexcept { - iterator end_field = end(); - for (iterator field = begin(); field != end_field; ++field) { - if (field.key_equals(key)) { - return field.value(); - } - } - return NO_SUCH_FIELD; -} -// In case you wonder why we need this, please see -// https://github.com/simdjson/simdjson/issues/323 -// People do seek keys in a case-insensitive manner. 
-inline simdjson_result object::at_key_case_insensitive( - std::string_view key) const noexcept { - iterator end_field = end(); - for (iterator field = begin(); field != end_field; ++field) { - if (field.key_equals_case_insensitive(key)) { - return field.value(); - } - } - return NO_SUCH_FIELD; -} - -// -// object::iterator inline implementation -// -simdjson_really_inline object::iterator::iterator( - const internal::tape_ref &_tape) noexcept : tape{_tape} {} -inline const key_value_pair object::iterator::operator*() const noexcept { - return key_value_pair(key(), value()); -} -inline bool object::iterator::operator!=(const object::iterator &other) const - noexcept { - return tape.json_index != other.tape.json_index; -} -inline bool object::iterator::operator==(const object::iterator &other) const - noexcept { - return tape.json_index == other.tape.json_index; -} -inline bool object::iterator::operator<(const object::iterator &other) const - noexcept { - return tape.json_index < other.tape.json_index; -} -inline bool object::iterator::operator<=(const object::iterator &other) const - noexcept { - return tape.json_index <= other.tape.json_index; -} -inline bool object::iterator::operator>=(const object::iterator &other) const - noexcept { - return tape.json_index >= other.tape.json_index; -} -inline bool object::iterator::operator>(const object::iterator &other) const - noexcept { - return tape.json_index > other.tape.json_index; -} -inline object::iterator &object::iterator::operator++() noexcept { - tape.json_index++; - tape.json_index = tape.after_element(); - return *this; -} -inline object::iterator object::iterator::operator++(int)noexcept { - object::iterator out = *this; - ++*this; - return out; -} -inline std::string_view object::iterator::key() const noexcept { - return tape.get_string_view(); -} -inline uint32_t object::iterator::key_length() const noexcept { - return tape.get_string_length(); -} -inline const char *object::iterator::key_c_str() const 
noexcept { - return reinterpret_cast( - &tape.doc->string_buf[size_t(tape.tape_value()) + sizeof(uint32_t)]); -} -inline element object::iterator::value() const noexcept { - return element(internal::tape_ref(tape.doc, tape.json_index + 1)); -} - -/** - * Design notes: - * Instead of constructing a string_view and then comparing it with a - * user-provided strings, it is probably more performant to have dedicated - * functions taking as a parameter the string we want to compare against - * and return true when they are equal. That avoids the creation of a temporary - * std::string_view. Though it is possible for the compiler to avoid entirely - * any overhead due to string_view, relying too much on compiler magic is - * problematic: compiler magic sometimes fail, and then what do you do? - * Also, enticing users to rely on high-performance function is probably better - * on the long run. - */ - -inline bool object::iterator::key_equals(std::string_view o) const noexcept { - // We use the fact that the key length can be computed quickly - // without access to the string buffer. - const uint32_t len = key_length(); - if (o.size() == len) { - // We avoid construction of a temporary string_view instance. - return (memcmp(o.data(), key_c_str(), len) == 0); - } - return false; -} - -inline bool object::iterator::key_equals_case_insensitive( - std::string_view o) const noexcept { - // We use the fact that the key length can be computed quickly - // without access to the string buffer. - const uint32_t len = key_length(); - if (o.size() == len) { - // See For case-insensitive string comparisons, avoid char-by-char - // functions - // https://lemire.me/blog/2020/04/30/for-case-insensitive-string-comparisons-avoid-char-by-char-functions/ - // Note that it might be worth rolling our own strncasecmp function, - // with vectorization. 
- return (simdjson_strncasecmp(o.data(), key_c_str(), len) == 0); - } - return false; -} -// -// key_value_pair inline implementation -// -inline key_value_pair::key_value_pair(std::string_view _key, - element _value) noexcept : key(_key), - value(_value) { -} - -} // namespace dom - -} // namespace simdjson - -#if defined(__cpp_lib_ranges) -static_assert(std::ranges::view); -static_assert(std::ranges::sized_range); -#if SIMDJSON_EXCEPTIONS -static_assert( - std::ranges::view>); -static_assert( - std::ranges::sized_range>); -#endif // SIMDJSON_EXCEPTIONS -#endif // defined(__cpp_lib_ranges) - -#endif // SIMDJSON_INLINE_OBJECT_H -/* end file include/simdjson/dom/object-inl.h */ -/* begin file include/simdjson/dom/parsedjson_iterator-inl.h */ -#ifndef SIMDJSON_INLINE_PARSEDJSON_ITERATOR_H -#define SIMDJSON_INLINE_PARSEDJSON_ITERATOR_H - -#include - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API - -namespace simdjson { - -// VS2017 reports deprecated warnings when you define a deprecated class's -// methods. -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_DEPRECATED_WARNING - -// Because of template weirdness, the actual class definition is inline in the -// document class -simdjson_warn_unused bool dom::parser::Iterator::is_ok() const { - return location < tape_length; -} - -// useful for debugging purposes -size_t dom::parser::Iterator::get_tape_location() const { return location; } - -// useful for debugging purposes -size_t dom::parser::Iterator::get_tape_length() const { return tape_length; } - -// returns the current depth (start at 1 with 0 reserved for the fictitious root -// node) -size_t dom::parser::Iterator::get_depth() const { return depth; } - -// A scope is a series of nodes at the same depth, typically it is either an -// object ({) or an array ([). The root node has type 'r'. 
-uint8_t dom::parser::Iterator::get_scope_type() const { - return depth_index[depth].scope_type; -} - -bool dom::parser::Iterator::move_forward() { - if (location + 1 >= tape_length) { - return false; // we are at the end! - } - - if ((current_type == '[') || (current_type == '{')) { - // We are entering a new scope - depth++; - assert(depth < max_depth); - depth_index[depth].start_of_scope = location; - depth_index[depth].scope_type = current_type; - } else if ((current_type == ']') || (current_type == '}')) { - // Leaving a scope. - depth--; - } else if (is_number()) { - // these types use 2 locations on the tape, not just one. - location += 1; - } - - location += 1; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); - return true; -} - -void dom::parser::Iterator::move_to_value() { - // assume that we are on a key, so move by 1. - location += 1; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); -} - -bool dom::parser::Iterator::move_to_key(const char *key) { - if (down()) { - do { - const bool right_key = (strcmp(get_string(), key) == 0); - move_to_value(); - if (right_key) { - return true; - } - } while (next()); - up(); - } - return false; -} - -bool dom::parser::Iterator::move_to_key_insensitive(const char *key) { - if (down()) { - do { - const bool right_key = - (simdjson_strcasecmp(get_string(), key) == 0); - move_to_value(); - if (right_key) { - return true; - } - } while (next()); - up(); - } - return false; -} - -bool dom::parser::Iterator::move_to_key(const char *key, uint32_t length) { - if (down()) { - do { - bool right_key = ((get_string_length() == length) && - (memcmp(get_string(), key, length) == 0)); - move_to_value(); - if (right_key) { - return true; - } - } while (next()); - up(); - } - return false; -} - -bool dom::parser::Iterator::move_to_index(uint32_t index) { - if (down()) { - uint32_t i = 0; - for (; i < index; i++) { - if (!next()) { - break; - } - } - if (i == index) { - 
return true; - } - up(); - } - return false; -} - -bool dom::parser::Iterator::prev() { - size_t target_location = location; - to_start_scope(); - size_t npos = location; - if (target_location == npos) { - return false; // we were already at the start - } - size_t oldnpos; - // we have that npos < target_location here - do { - oldnpos = npos; - if ((current_type == '[') || (current_type == '{')) { - // we need to jump - npos = uint32_t(current_val); - } else { - npos = - npos + ((current_type == 'd' || current_type == 'l') ? 2 : 1); - } - } while (npos < target_location); - location = oldnpos; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); - return true; -} - -bool dom::parser::Iterator::up() { - if (depth == 1) { - return false; // don't allow moving back to root - } - to_start_scope(); - // next we just move to the previous value - depth--; - location -= 1; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); - return true; -} - -bool dom::parser::Iterator::down() { - if (location + 1 >= tape_length) { - return false; - } - if ((current_type == '[') || (current_type == '{')) { - size_t npos = uint32_t(current_val); - if (npos == location + 2) { - return false; // we have an empty scope - } - depth++; - assert(depth < max_depth); - location = location + 1; - depth_index[depth].start_of_scope = location; - depth_index[depth].scope_type = current_type; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); - return true; - } - return false; -} - -void dom::parser::Iterator::to_start_scope() { - location = depth_index[depth].start_of_scope; - current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); -} - -bool dom::parser::Iterator::next() { - size_t npos; - if ((current_type == '[') || (current_type == '{')) { - // we need to jump - npos = uint32_t(current_val); - } else { - npos = location + (is_number() ? 
2 : 1); - } - uint64_t next_val = doc.tape[npos]; - uint8_t next_type = uint8_t(next_val >> 56); - if ((next_type == ']') || (next_type == '}')) { - return false; // we reached the end of the scope - } - location = npos; - current_val = next_val; - current_type = next_type; - return true; -} -dom::parser::Iterator::Iterator(const dom::parser &pj) noexcept(false) - : doc(pj.doc) { -#if SIMDJSON_EXCEPTIONS - if (!pj.valid) { - throw simdjson_error(pj.error); - } -#else - if (!pj.valid) { - return; - } // abort() usage is forbidden in the library -#endif - - max_depth = pj.max_depth(); - depth_index = new scopeindex_t[max_depth + 1]; - depth_index[0].start_of_scope = location; - current_val = doc.tape[location++]; - current_type = uint8_t(current_val >> 56); - depth_index[0].scope_type = current_type; - tape_length = size_t(current_val & internal::JSON_VALUE_MASK); - if (location < tape_length) { - // If we make it here, then depth_capacity must >=2, but the compiler - // may not know this. 
- current_val = doc.tape[location]; - current_type = uint8_t(current_val >> 56); - depth++; - assert(depth < max_depth); - depth_index[depth].start_of_scope = location; - depth_index[depth].scope_type = current_type; - } -} -dom::parser::Iterator::Iterator(const dom::parser::Iterator &o) noexcept - : doc(o.doc), - max_depth(o.depth), - depth(o.depth), - location(o.location), - tape_length(o.tape_length), - current_type(o.current_type), - current_val(o.current_val) { - depth_index = new scopeindex_t[max_depth + 1]; - std::memcpy( - depth_index, o.depth_index, (depth + 1) * sizeof(depth_index[0])); -} - -dom::parser::Iterator::~Iterator() noexcept { - if (depth_index) { - delete[] depth_index; - } -} - -bool dom::parser::Iterator::print(std::ostream &os, bool escape_strings) const { - if (!is_ok()) { - return false; - } - switch (current_type) { - case '"': // we have a string - os << '"'; - if (escape_strings) { - os << internal::escape_json_string( - std::string_view(get_string(), get_string_length())); - } else { - // was: os << get_string();, but given that we can include null - // chars, we - // have to do something crazier: - std::copy(get_string(), - get_string() + get_string_length(), - std::ostream_iterator(os)); - } - os << '"'; - break; - case 'l': // we have a long int - os << get_integer(); - break; - case 'u': - os << get_unsigned_integer(); - break; - case 'd': - os << get_double(); - break; - case 'n': // we have a null - os << "null"; - break; - case 't': // we have a true - os << "true"; - break; - case 'f': // we have a false - os << "false"; - break; - case '{': // we have an object - case '}': // we end an object - case '[': // we start an array - case ']': // we end an array - os << char(current_type); - break; - default: - return false; - } - return true; -} - -bool dom::parser::Iterator::move_to(const char *pointer, uint32_t length) { - char *new_pointer = nullptr; - if (pointer[0] == '#') { - // Converting fragment representation to string 
representation - new_pointer = new char[length]; - uint32_t new_length = 0; - for (uint32_t i = 1; i < length; i++) { - if (pointer[i] == '%' && pointer[i + 1] == 'x') { -#if __cpp_exceptions - try { -#endif - int fragment = - std::stoi(std::string(&pointer[i + 2], 2), nullptr, 16); - if (fragment == '\\' || fragment == '"' || - (fragment <= 0x1F)) { - // escaping the character - new_pointer[new_length] = '\\'; - new_length++; - } - new_pointer[new_length] = char(fragment); - i += 3; -#if __cpp_exceptions - } catch (std::invalid_argument &) { - delete[] new_pointer; - return false; // the fragment is invalid - } -#endif - } else { - new_pointer[new_length] = pointer[i]; - } - new_length++; - } - length = new_length; - pointer = new_pointer; - } - - // saving the current state - size_t depth_s = depth; - size_t location_s = location; - uint8_t current_type_s = current_type; - uint64_t current_val_s = current_val; - - rewind(); // The json pointer is used from the root of the document. - - bool found = relative_move_to(pointer, length); - delete[] new_pointer; - - if (!found) { - // since the pointer has found nothing, we get back to the original - // position. 
- depth = depth_s; - location = location_s; - current_type = current_type_s; - current_val = current_val_s; - } - - return found; -} - -bool dom::parser::Iterator::relative_move_to(const char *pointer, - uint32_t length) { - if (length == 0) { - // returns the whole document - return true; - } - - if (pointer[0] != '/') { - // '/' must be the first character - return false; - } - - // finding the key in an object or the index in an array - std::string key_or_index; - uint32_t offset = 1; - - // checking for the "-" case - if (is_array() && pointer[1] == '-') { - if (length != 2) { - // the pointer must be exactly "/-" - // there can't be anything more after '-' as an index - return false; - } - key_or_index = '-'; - offset = length; // will skip the loop coming right after - } - - // We either transform the first reference token to a valid json key - // or we make sure it is a valid index in an array. - for (; offset < length; offset++) { - if (pointer[offset] == '/') { - // beginning of the next key or index - break; - } - if (is_array() && (pointer[offset] < '0' || pointer[offset] > '9')) { - // the index of an array must be an integer - // we also make sure std::stoi won't discard whitespaces later - return false; - } - if (pointer[offset] == '~') { - // "~1" represents "/" - if (pointer[offset + 1] == '1') { - key_or_index += '/'; - offset++; - continue; - } - // "~0" represents "~" - if (pointer[offset + 1] == '0') { - key_or_index += '~'; - offset++; - continue; - } - } - if (pointer[offset] == '\\') { - if (pointer[offset + 1] == '\\' || pointer[offset + 1] == '"' || - (pointer[offset + 1] <= 0x1F)) { - key_or_index += pointer[offset + 1]; - offset++; - continue; - } - return false; // invalid escaped character - } - if (pointer[offset] == '\"') { - // unescaped quote character. this is an invalid case. - // lets do nothing and assume most pointers will be valid. - // it won't find any corresponding json key anyway. 
- // return false; - } - key_or_index += pointer[offset]; - } - - bool found = false; - if (is_object()) { - if (move_to_key(key_or_index.c_str(), - uint32_t(key_or_index.length()))) { - found = relative_move_to(pointer + offset, length - offset); - } - } else if (is_array()) { - if (key_or_index == "-") { // handling "-" case first - if (down()) { - while (next()) - ; // moving to the end of the array - // moving to the nonexistent value right after... - size_t npos; - if ((current_type == '[') || (current_type == '{')) { - // we need to jump - npos = uint32_t(current_val); - } else { - npos = - location + - ((current_type == 'd' || current_type == 'l') ? 2 : 1); - } - location = npos; - current_val = doc.tape[npos]; - current_type = uint8_t(current_val >> 56); - return true; // how could it fail ? - } - } else { // regular numeric index - // The index can't have a leading '0' - if (key_or_index[0] == '0' && key_or_index.length() > 1) { - return false; - } - // it cannot be empty - if (key_or_index.length() == 0) { - return false; - } - // we already checked the index contains only valid digits - uint32_t index = std::stoi(key_or_index); - if (move_to_index(index)) { - found = relative_move_to(pointer + offset, length - offset); - } - } - } - - return found; -} - -SIMDJSON_POP_DISABLE_WARNINGS -} // namespace simdjson - -#endif // SIMDJSON_DISABLE_DEPRECATED_API - - -#endif // SIMDJSON_INLINE_PARSEDJSON_ITERATOR_H -/* end file include/simdjson/dom/parsedjson_iterator-inl.h */ -/* begin file include/simdjson/dom/parser-inl.h */ -#ifndef SIMDJSON_INLINE_PARSER_H -#define SIMDJSON_INLINE_PARSER_H - -#include -#include - -namespace simdjson { -namespace dom { - -// -// parser inline implementation -// -simdjson_really_inline parser::parser(size_t max_capacity) noexcept - : _max_capacity{max_capacity}, - loaded_bytes(nullptr) {} -simdjson_really_inline parser::parser(parser &&other) noexcept = default; -simdjson_really_inline parser &parser::operator=(parser &&other) 
noexcept = - default; - -inline bool parser::is_valid() const noexcept { return valid; } -inline int parser::get_error_code() const noexcept { return error; } -inline std::string parser::get_error_message() const noexcept { - return error_message(error); -} - -inline bool parser::dump_raw_tape(std::ostream &os) const noexcept { - return valid ? doc.dump_raw_tape(os) : false; -} - -inline simdjson_result parser::read_file( - const std::string &path) noexcept { - // Open the file - SIMDJSON_PUSH_DISABLE_WARNINGS - SIMDJSON_DISABLE_DEPRECATED_WARNING // Disable CRT_SECURE warning on MSVC: - // manually verified this is safe - std::FILE *fp = std::fopen(path.c_str(), "rb"); - SIMDJSON_POP_DISABLE_WARNINGS - - if (fp == nullptr) { - return IO_ERROR; - } - - // Get the file size - if (std::fseek(fp, 0, SEEK_END) < 0) { - std::fclose(fp); - return IO_ERROR; - } -#if defined(SIMDJSON_VISUAL_STUDIO) && !SIMDJSON_IS_32BITS - __int64 len = _ftelli64(fp); - if (len == -1L) { - std::fclose(fp); - return IO_ERROR; - } -#else - long len = std::ftell(fp); - if ((len < 0) || (len == LONG_MAX)) { - std::fclose(fp); - return IO_ERROR; - } -#endif - - // Make sure we have enough capacity to load the file - if (_loaded_bytes_capacity < size_t(len)) { - loaded_bytes.reset(internal::allocate_padded_buffer(len)); - if (!loaded_bytes) { - std::fclose(fp); - return MEMALLOC; - } - _loaded_bytes_capacity = len; - } - - // Read the string - std::rewind(fp); - size_t bytes_read = std::fread(loaded_bytes.get(), 1, len, fp); - if (std::fclose(fp) != 0 || bytes_read != size_t(len)) { - return IO_ERROR; - } - - return bytes_read; -} - -inline simdjson_result parser::load(const std::string &path) & - noexcept { - size_t len; - auto _error = read_file(path).get(len); - if (_error) { - return _error; - } - return parse(loaded_bytes.get(), len, false); -} - -inline simdjson_result parser::load_many( - const std::string &path, size_t batch_size) noexcept { - size_t len; - auto _error = 
read_file(path).get(len); - if (_error) { - return _error; - } - if (batch_size < MINIMAL_BATCH_SIZE) { - batch_size = MINIMAL_BATCH_SIZE; - } - return document_stream( - *this, - reinterpret_cast(loaded_bytes.get()), - len, - batch_size); -} - -inline simdjson_result parser::parse_into_document( - document &provided_doc, - const uint8_t *buf, - size_t len, - bool realloc_if_needed) & - noexcept { - // Important: we need to ensure that document has enough capacity. - // Important: It is possible that provided_doc is actually the internal - // 'doc' within the parser!!! - error_code _error = ensure_capacity(provided_doc, len); - if (_error) { - return _error; - } - if (realloc_if_needed) { - // Make sure we have enough capacity to copy len bytes - if (!loaded_bytes || _loaded_bytes_capacity < len) { - loaded_bytes.reset(internal::allocate_padded_buffer(len)); - if (!loaded_bytes) { - return MEMALLOC; - } - _loaded_bytes_capacity = len; - } - std::memcpy(static_cast(loaded_bytes.get()), buf, len); - } - _error = implementation->parse( - realloc_if_needed - ? 
reinterpret_cast(loaded_bytes.get()) - : buf, - len, - provided_doc); - - if (_error) { - return _error; - } - - return provided_doc.root(); -} - -simdjson_really_inline simdjson_result parser::parse_into_document( - document &provided_doc, - const char *buf, - size_t len, - bool realloc_if_needed) & - noexcept { - return parse_into_document(provided_doc, - reinterpret_cast(buf), - len, - realloc_if_needed); -} -simdjson_really_inline simdjson_result parser::parse_into_document( - document &provided_doc, const std::string &s) & - noexcept { - return parse_into_document(provided_doc, - s.data(), - s.length(), - s.capacity() - s.length() < SIMDJSON_PADDING); -} -simdjson_really_inline simdjson_result parser::parse_into_document( - document &provided_doc, const padded_string &s) & - noexcept { - return parse_into_document(provided_doc, s.data(), s.length(), false); -} - - -inline simdjson_result parser::parse(const uint8_t *buf, - size_t len, - bool realloc_if_needed) & - noexcept { - return parse_into_document(doc, buf, len, realloc_if_needed); -} - -simdjson_really_inline simdjson_result parser::parse( - const char *buf, size_t len, bool realloc_if_needed) & - noexcept { - return parse( - reinterpret_cast(buf), len, realloc_if_needed); -} -simdjson_really_inline simdjson_result parser::parse( - const std::string &s) & - noexcept { - return parse( - s.data(), s.length(), s.capacity() - s.length() < SIMDJSON_PADDING); -} -simdjson_really_inline simdjson_result parser::parse( - const padded_string &s) & - noexcept { - return parse(s.data(), s.length(), false); -} - -inline simdjson_result parser::parse_many( - const uint8_t *buf, size_t len, size_t batch_size) noexcept { - if (batch_size < MINIMAL_BATCH_SIZE) { - batch_size = MINIMAL_BATCH_SIZE; - } - return document_stream(*this, buf, len, batch_size); -} -inline simdjson_result parser::parse_many( - const char *buf, size_t len, size_t batch_size) noexcept { - return parse_many(reinterpret_cast(buf), len, batch_size); 
-} -inline simdjson_result parser::parse_many( - const std::string &s, size_t batch_size) noexcept { - return parse_many(s.data(), s.length(), batch_size); -} -inline simdjson_result parser::parse_many( - const padded_string &s, size_t batch_size) noexcept { - return parse_many(s.data(), s.length(), batch_size); -} - -simdjson_really_inline size_t parser::capacity() const noexcept { - return implementation ? implementation->capacity() : 0; -} -simdjson_really_inline size_t parser::max_capacity() const noexcept { - return _max_capacity; -} -simdjson_really_inline size_t parser::max_depth() const noexcept { - return implementation ? implementation->max_depth() : DEFAULT_MAX_DEPTH; -} - -simdjson_warn_unused inline error_code parser::allocate( - size_t capacity, size_t max_depth) noexcept { - // - // Reallocate implementation if needed - // - error_code err; - if (implementation) { - err = implementation->allocate(capacity, max_depth); - } else { - err = simdjson::get_active_implementation() - ->create_dom_parser_implementation( - capacity, max_depth, implementation); - } - if (err) { - return err; - } - return SUCCESS; -} - -#ifndef SIMDJSON_DISABLE_DEPRECATED_API -simdjson_warn_unused inline bool parser::allocate_capacity( - size_t capacity, size_t max_depth) noexcept { - return !allocate(capacity, max_depth); -} -#endif // SIMDJSON_DISABLE_DEPRECATED_API - -inline error_code parser::ensure_capacity(size_t desired_capacity) noexcept { - return ensure_capacity(doc, desired_capacity); -} - - -inline error_code parser::ensure_capacity(document &target_document, - size_t desired_capacity) noexcept { - // 1. It is wasteful to allocate a document and a parser for documents - // spanning less than MINIMAL_DOCUMENT_CAPACITY bytes. - // 2. If we allow desired_capacity = 0 then it is possible to exit this - // function with implementation == nullptr. 
- if (desired_capacity < MINIMAL_DOCUMENT_CAPACITY) { - desired_capacity = MINIMAL_DOCUMENT_CAPACITY; - } - // If we don't have enough capacity, (try to) automatically bump it. - // If the document needs allocation, do it too. - // Both in one if statement to minimize unlikely branching. - // - // Note: we must make sure that this function is called if capacity() == 0. - // We do so because we - // ensure that desired_capacity > 0. - if (simdjson_unlikely(capacity() < desired_capacity || - target_document.capacity() < desired_capacity)) { - if (desired_capacity > max_capacity()) { - return error = CAPACITY; - } - error_code err1 = target_document.capacity() < desired_capacity - ? target_document.allocate(desired_capacity) - : SUCCESS; - error_code err2 = capacity() < desired_capacity - ? allocate(desired_capacity, max_depth()) - : SUCCESS; - if (err1 != SUCCESS) { - return error = err1; - } - if (err2 != SUCCESS) { - return error = err2; - } - } - return SUCCESS; -} - -simdjson_really_inline void parser::set_max_capacity( - size_t max_capacity) noexcept { - if (max_capacity < MINIMAL_DOCUMENT_CAPACITY) { - _max_capacity = max_capacity; - } else { - _max_capacity = MINIMAL_DOCUMENT_CAPACITY; - } -} - -} // namespace dom -} // namespace simdjson - -#endif // SIMDJSON_INLINE_PARSER_H -/* end file include/simdjson/dom/parser-inl.h */ -/* begin file include/simdjson/internal/tape_ref-inl.h */ -#ifndef SIMDJSON_INLINE_TAPE_REF_H -#define SIMDJSON_INLINE_TAPE_REF_H - -#include - -namespace simdjson { -namespace internal { - -// -// tape_ref inline implementation -// -simdjson_really_inline tape_ref::tape_ref() noexcept : doc{nullptr}, - json_index{0} {} -simdjson_really_inline tape_ref::tape_ref(const dom::document *_doc, - size_t _json_index) noexcept - : doc{_doc}, - json_index{_json_index} {} - - -simdjson_really_inline bool tape_ref::is_document_root() const noexcept { - return json_index == 1; // should we ever change the structure of the tape, - // this should get 
updated. -} - -// Some value types have a specific on-tape word value. It can be faster -// to check the type by doing a word-to-word comparison instead of extracting -// the -// most significant 8 bits. - -simdjson_really_inline bool tape_ref::is_double() const noexcept { - constexpr uint64_t tape_double = uint64_t(tape_type::DOUBLE) << 56; - return doc->tape[json_index] == tape_double; -} -simdjson_really_inline bool tape_ref::is_int64() const noexcept { - constexpr uint64_t tape_int64 = uint64_t(tape_type::INT64) << 56; - return doc->tape[json_index] == tape_int64; -} -simdjson_really_inline bool tape_ref::is_uint64() const noexcept { - constexpr uint64_t tape_uint64 = uint64_t(tape_type::UINT64) << 56; - return doc->tape[json_index] == tape_uint64; -} -simdjson_really_inline bool tape_ref::is_false() const noexcept { - constexpr uint64_t tape_false = uint64_t(tape_type::FALSE_VALUE) << 56; - return doc->tape[json_index] == tape_false; -} -simdjson_really_inline bool tape_ref::is_true() const noexcept { - constexpr uint64_t tape_true = uint64_t(tape_type::TRUE_VALUE) << 56; - return doc->tape[json_index] == tape_true; -} -simdjson_really_inline bool tape_ref::is_null_on_tape() const noexcept { - constexpr uint64_t tape_null = uint64_t(tape_type::NULL_VALUE) << 56; - return doc->tape[json_index] == tape_null; -} - -inline size_t tape_ref::after_element() const noexcept { - switch (tape_ref_type()) { - case tape_type::START_ARRAY: - case tape_type::START_OBJECT: - return matching_brace_index(); - case tape_type::UINT64: - case tape_type::INT64: - case tape_type::DOUBLE: - return json_index + 2; - default: - return json_index + 1; - } -} -simdjson_really_inline tape_type tape_ref::tape_ref_type() const noexcept { - return static_cast(doc->tape[json_index] >> 56); -} -simdjson_really_inline uint64_t internal::tape_ref::tape_value() const - noexcept { - return doc->tape[json_index] & internal::JSON_VALUE_MASK; -} -simdjson_really_inline uint32_t 
internal::tape_ref::matching_brace_index() const - noexcept { - return uint32_t(doc->tape[json_index]); -} -simdjson_really_inline uint32_t internal::tape_ref::scope_count() const - noexcept { - return uint32_t((doc->tape[json_index] >> 32) & internal::JSON_COUNT_MASK); -} - -template -simdjson_really_inline T tape_ref::next_tape_value() const noexcept { - static_assert(sizeof(T) == sizeof(uint64_t), - "next_tape_value() template parameter must be 64-bit"); - // Though the following is tempting... - // return *reinterpret_cast(&doc->tape[json_index + 1]); - // It is not generally safe. It is safer, and often faster to rely - // on memcpy. Yes, it is uglier, but it is also encapsulated. - T x; - std::memcpy(&x, &doc->tape[json_index + 1], sizeof(uint64_t)); - return x; -} - -simdjson_really_inline uint32_t internal::tape_ref::get_string_length() const - noexcept { - size_t string_buf_index = size_t(tape_value()); - uint32_t len; - std::memcpy(&len, &doc->string_buf[string_buf_index], sizeof(len)); - return len; -} - -simdjson_really_inline const char *internal::tape_ref::get_c_str() const - noexcept { - size_t string_buf_index = size_t(tape_value()); - return reinterpret_cast( - &doc->string_buf[string_buf_index + sizeof(uint32_t)]); -} - -inline std::string_view internal::tape_ref::get_string_view() const noexcept { - return std::string_view(get_c_str(), get_string_length()); -} - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INLINE_TAPE_REF_H -/* end file include/simdjson/internal/tape_ref-inl.h */ -/* begin file include/simdjson/dom/serialization-inl.h */ - -#ifndef SIMDJSON_SERIALIZATION_INL_H -#define SIMDJSON_SERIALIZATION_INL_H - - -#include -#include - -namespace simdjson { -namespace dom { -inline bool parser::print_json(std::ostream &os) const noexcept { - if (!valid) { - return false; - } - simdjson::internal::string_builder<> sb; - sb.append(doc.root()); - std::string_view answer = sb.str(); - os << answer; - return true; -} -} 
-/*** - * Number utility functions - **/ - - -namespace { -/**@private - * Escape sequence like \b or \u0001 - * We expect that most compilers will use 8 bytes for this data structure. - **/ -struct escape_sequence { - uint8_t length; - const char - string[7]; // technically, we only ever need 6 characters, we pad to 8 -}; -/**@private - * This converts a signed integer into a character sequence. - * The caller is responsible for providing enough memory (at least - * 20 characters.) - * Though various runtime libraries provide itoa functions, - * it is not part of the C++ standard. The C++17 standard - * adds the to_chars functions which would do as well, but - * we want to support C++11. - */ -char *fast_itoa(char *output, int64_t value) noexcept { - // This is a standard implementation of itoa. - char buffer[20]; - uint64_t value_positive; - // In general, negating a signed integer is unsafe. - if (value < 0) { - *output++ = '-'; - // Doing value_positive = -value; while avoiding - // undefined behavior warnings. - // It assumes two complement's which is universal at this - // point in time. - std::memcpy(&value_positive, &value, sizeof(value)); - value_positive = (~value_positive) + 1; // this is a negation - } else { - value_positive = value; - } - // We work solely with value_positive. It *might* be easier - // for an optimizing compiler to deal with an unsigned variable - // as far as performance goes. - const char *const end_buffer = buffer + 20; - char *write_pointer = buffer + 19; - // A faster approach is possible if we expect large integers: - // unroll the loop (work in 100s, 1000s) and use some kind of - // memoization. 
- while (value_positive >= 10) { - *write_pointer-- = char('0' + (value_positive % 10)); - value_positive /= 10; - } - *write_pointer = char('0' + value_positive); - size_t len = end_buffer - write_pointer; - std::memcpy(output, write_pointer, len); - return output + len; -} -/**@private - * This converts an unsigned integer into a character sequence. - * The caller is responsible for providing enough memory (at least - * 19 characters.) - * Though various runtime libraries provide itoa functions, - * it is not part of the C++ standard. The C++17 standard - * adds the to_chars functions which would do as well, but - * we want to support C++11. - */ -char *fast_itoa(char *output, uint64_t value) noexcept { - // This is a standard implementation of itoa. - char buffer[20]; - const char *const end_buffer = buffer + 20; - char *write_pointer = buffer + 19; - // A faster approach is possible if we expect large integers: - // unroll the loop (work in 100s, 1000s) and use some kind of - // memoization. - while (value >= 10) { - *write_pointer-- = char('0' + (value % 10)); - value /= 10; - }; - *write_pointer = char('0' + value); - size_t len = end_buffer - write_pointer; - std::memcpy(output, write_pointer, len); - return output + len; -} -} // anonymous namespace -namespace internal { - -/*** - * Minifier/formatter code. - **/ - -simdjson_really_inline void mini_formatter::number(uint64_t x) { - char number_buffer[24]; - char *newp = fast_itoa(number_buffer, x); - buffer.insert(buffer.end(), number_buffer, newp); -} - -simdjson_really_inline void mini_formatter::number(int64_t x) { - char number_buffer[24]; - char *newp = fast_itoa(number_buffer, x); - buffer.insert(buffer.end(), number_buffer, newp); -} - -simdjson_really_inline void mini_formatter::number(double x) { - char number_buffer[24]; - // Currently, passing the nullptr to the second argument is - // safe because our implementation does not check the second - // argument. 
- char *newp = internal::to_chars(number_buffer, nullptr, x); - buffer.insert(buffer.end(), number_buffer, newp); -} - -simdjson_really_inline void mini_formatter::start_array() { one_char('['); } -simdjson_really_inline void mini_formatter::end_array() { one_char(']'); } -simdjson_really_inline void mini_formatter::start_object() { one_char('{'); } -simdjson_really_inline void mini_formatter::end_object() { one_char('}'); } -simdjson_really_inline void mini_formatter::comma() { one_char(','); } - - -simdjson_really_inline void mini_formatter::true_atom() { - const char *s = "true"; - buffer.insert(buffer.end(), s, s + 4); -} -simdjson_really_inline void mini_formatter::false_atom() { - const char *s = "false"; - buffer.insert(buffer.end(), s, s + 5); -} -simdjson_really_inline void mini_formatter::null_atom() { - const char *s = "null"; - buffer.insert(buffer.end(), s, s + 4); -} -simdjson_really_inline void mini_formatter::one_char(char c) { - buffer.push_back(c); -} -simdjson_really_inline void mini_formatter::key(std::string_view unescaped) { - string(unescaped); - one_char(':'); -} -simdjson_really_inline void mini_formatter::string(std::string_view unescaped) { - one_char('\"'); - size_t i = 0; - // Fast path for the case where we have no control character, no ", and no - // backslash. - // This should include most keys. - // - // We would like to use 'bool' but some compilers take offense to bitwise - // operation - // with bool types. 
- constexpr static char needs_escaping[] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - for (; i + 8 <= unescaped.length(); i += 8) { - // Poor's man vectorization. This could get much faster if we used SIMD. - // - // It is not the case that replacing '|' with '||' would be neutral - // performance-wise. - if (needs_escaping[uint8_t(unescaped[i])] | - needs_escaping[uint8_t(unescaped[i + 1])] | - needs_escaping[uint8_t(unescaped[i + 2])] | - needs_escaping[uint8_t(unescaped[i + 3])] | - needs_escaping[uint8_t(unescaped[i + 4])] | - needs_escaping[uint8_t(unescaped[i + 5])] | - needs_escaping[uint8_t(unescaped[i + 6])] | - needs_escaping[uint8_t(unescaped[i + 7])]) { - break; - } - } - for (; i < unescaped.length(); i++) { - if (needs_escaping[uint8_t(unescaped[i])]) { - break; - } - } - // The following is also possible and omits a 256-byte table, but it is - // slower: - // for (; (i < unescaped.length()) && (uint8_t(unescaped[i]) > 0x1F) - // && (unescaped[i] != '\"') && (unescaped[i] != '\\'); i++) {} - - // At least for long strings, the following should be fast. We could - // do better by integrating the checks and the insertion. 
- buffer.insert(buffer.end(), unescaped.data(), unescaped.data() + i); - // We caught a control character if we enter this loop (slow). - // Note that we are do not restart from the beginning, but rather we - // continue - // from the point where we encountered something that requires escaping. - for (; i < unescaped.length(); i++) { - switch (unescaped[i]) { - case '\"': { - const char *s = "\\\""; - buffer.insert(buffer.end(), s, s + 2); - } break; - case '\\': { - const char *s = "\\\\"; - buffer.insert(buffer.end(), s, s + 2); - } break; - default: - if (uint8_t(unescaped[i]) <= 0x1F) { - // If packed, this uses 8 * 32 bytes. - // Note that we expect most compilers to embed this code in - // the data - // section. - constexpr static escape_sequence escaped[32] = { - {6, "\\u0000"}, {6, "\\u0001"}, {6, "\\u0002"}, - {6, "\\u0003"}, {6, "\\u0004"}, {6, "\\u0005"}, - {6, "\\u0006"}, {6, "\\u0007"}, {2, "\\b"}, - {2, "\\t"}, {2, "\\n"}, {6, "\\u000b"}, - {2, "\\f"}, {2, "\\r"}, {6, "\\u000e"}, - {6, "\\u000f"}, {6, "\\u0010"}, {6, "\\u0011"}, - {6, "\\u0012"}, {6, "\\u0013"}, {6, "\\u0014"}, - {6, "\\u0015"}, {6, "\\u0016"}, {6, "\\u0017"}, - {6, "\\u0018"}, {6, "\\u0019"}, {6, "\\u001a"}, - {6, "\\u001b"}, {6, "\\u001c"}, {6, "\\u001d"}, - {6, "\\u001e"}, {6, "\\u001f"}}; - auto u = escaped[uint8_t(unescaped[i])]; - buffer.insert(buffer.end(), u.string, u.string + u.length); - } else { - one_char(unescaped[i]); - } - } // switch - } // for - one_char('\"'); -} - -inline void mini_formatter::clear() { buffer.clear(); } - -simdjson_really_inline std::string_view mini_formatter::str() const { - return std::string_view(buffer.data(), buffer.size()); -} - - -/*** - * String building code. 
- **/ - -template -inline void string_builder::append(simdjson::dom::element value) { - // using tape_type = simdjson::internal::tape_type; - size_t depth = 0; - constexpr size_t MAX_DEPTH = 16; - bool is_object[MAX_DEPTH]; - is_object[0] = false; - bool after_value = false; - - internal::tape_ref iter(value.tape); - do { - // print commas after each value - if (after_value) { - format.comma(); - } - // If we are in an object, print the next key and :, and skip to the - // next - // value. - if (is_object[depth]) { - format.key(iter.get_string_view()); - iter.json_index++; - } - switch (iter.tape_ref_type()) { - // Arrays - case tape_type::START_ARRAY: { - // If we're too deep, we need to recurse to go deeper. - depth++; - if (simdjson_unlikely(depth >= MAX_DEPTH)) { - append(simdjson::dom::array(iter)); - iter.json_index = - iter.matching_brace_index() - 1; // Jump to the ] - depth--; - break; - } - - // Output start [ - format.start_array(); - iter.json_index++; - - // Handle empty [] (we don't want to come back around and print - // commas) - if (iter.tape_ref_type() == tape_type::END_ARRAY) { - format.end_array(); - depth--; - break; - } - - is_object[depth] = false; - after_value = false; - continue; - } - - // Objects - case tape_type::START_OBJECT: { - // If we're too deep, we need to recurse to go deeper. 
- depth++; - if (simdjson_unlikely(depth >= MAX_DEPTH)) { - append(simdjson::dom::object(iter)); - iter.json_index = - iter.matching_brace_index() - 1; // Jump to the } - depth--; - break; - } - - // Output start { - format.start_object(); - iter.json_index++; - - // Handle empty {} (we don't want to come back around and print - // commas) - if (iter.tape_ref_type() == tape_type::END_OBJECT) { - format.end_object(); - depth--; - break; - } - - is_object[depth] = true; - after_value = false; - continue; - } - - // Scalars - case tape_type::STRING: - format.string(iter.get_string_view()); - break; - case tape_type::INT64: - format.number(iter.next_tape_value()); - iter.json_index++; // numbers take up 2 spots, so we need to - // increment - // extra - break; - case tape_type::UINT64: - format.number(iter.next_tape_value()); - iter.json_index++; // numbers take up 2 spots, so we need to - // increment - // extra - break; - case tape_type::DOUBLE: - format.number(iter.next_tape_value()); - iter.json_index++; // numbers take up 2 spots, so we need to - // increment - // extra - break; - case tape_type::TRUE_VALUE: - format.true_atom(); - break; - case tape_type::FALSE_VALUE: - format.false_atom(); - break; - case tape_type::NULL_VALUE: - format.null_atom(); - break; - - // These are impossible - case tape_type::END_ARRAY: - case tape_type::END_OBJECT: - case tape_type::ROOT: - SIMDJSON_UNREACHABLE(); - } - iter.json_index++; - after_value = true; - - // Handle multiple ends in a row - while (depth != 0 && (iter.tape_ref_type() == tape_type::END_ARRAY || - iter.tape_ref_type() == tape_type::END_OBJECT)) { - if (iter.tape_ref_type() == tape_type::END_ARRAY) { - format.end_array(); - } else { - format.end_object(); - } - depth--; - iter.json_index++; - } - - // Stop when we're at depth 0 - } while (depth != 0); -} - -template -inline void string_builder::append(simdjson::dom::object value) { - format.start_object(); - auto pair = value.begin(); - auto end = value.end(); - 
if (pair != end) { - append(*pair); - for (++pair; pair != end; ++pair) { - format.comma(); - append(*pair); - } - } - format.end_object(); -} - -template -inline void string_builder::append(simdjson::dom::array value) { - format.start_array(); - auto iter = value.begin(); - auto end = value.end(); - if (iter != end) { - append(*iter); - for (++iter; iter != end; ++iter) { - format.comma(); - append(*iter); - } - } - format.end_array(); -} - -template -simdjson_really_inline void string_builder::append( - simdjson::dom::key_value_pair kv) { - format.key(kv.key); - append(kv.value); -} - -template -simdjson_really_inline void string_builder::clear() { - format.clear(); -} - -template -simdjson_really_inline std::string_view string_builder::str() - const { - return format.str(); -} - - -} // namespace internal -} // namespace simdjson - -#endif -/* end file include/simdjson/dom/serialization-inl.h */ - -SIMDJSON_POP_DISABLE_WARNINGS - -#endif // SIMDJSON_DOM_H -/* end file include/simdjson/dom.h */ -/* begin file include/simdjson/builtin.h */ -#ifndef SIMDJSON_BUILTIN_H -#define SIMDJSON_BUILTIN_H - -/* begin file include/simdjson/implementations.h */ -#ifndef SIMDJSON_IMPLEMENTATIONS_H -#define SIMDJSON_IMPLEMENTATIONS_H - -/* begin file include/simdjson/implementation-base.h */ -#ifndef SIMDJSON_IMPLEMENTATION_BASE_H -#define SIMDJSON_IMPLEMENTATION_BASE_H - -/** - * @file - * - * Includes common stuff needed for implementations. - */ - - -// Implementation-internal files (must be included before the implementations -// themselves, to keep -// amalgamation working--otherwise, the first time a file is included, it might -// be put inside the -// #ifdef SIMDJSON_IMPLEMENTATION_ARM64/FALLBACK/etc., which means the other -// implementations can't -// compile unless that implementation is turned on). 
-/* begin file include/simdjson/internal/jsoncharutils_tables.h */ -#ifndef SIMDJSON_INTERNAL_JSONCHARUTILS_TABLES_H -#define SIMDJSON_INTERNAL_JSONCHARUTILS_TABLES_H - - -#ifdef JSON_TEST_STRINGS -void found_string(const uint8_t *buf, - const uint8_t *parsed_begin, - const uint8_t *parsed_end); -void found_bad_string(const uint8_t *buf); -#endif - -namespace simdjson { -namespace internal { -// structural chars here are -// they are { 0x7b } 0x7d : 0x3a [ 0x5b ] 0x5d , 0x2c (and NULL) -// we are also interested in the four whitespace characters -// space 0x20, linefeed 0x0a, horizontal tab 0x09 and carriage return 0x0d - -extern SIMDJSON_DLLIMPORTEXPORT const bool - structural_or_whitespace_negated[256]; -extern SIMDJSON_DLLIMPORTEXPORT const bool structural_or_whitespace[256]; -extern SIMDJSON_DLLIMPORTEXPORT const uint32_t digit_to_val32[886]; - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_JSONCHARUTILS_TABLES_H -/* end file include/simdjson/internal/jsoncharutils_tables.h */ -/* begin file include/simdjson/internal/numberparsing_tables.h */ -#ifndef SIMDJSON_INTERNAL_NUMBERPARSING_TABLES_H -#define SIMDJSON_INTERNAL_NUMBERPARSING_TABLES_H - - -namespace simdjson { -namespace internal { -/** - * The smallest non-zero float (binary64) is 2^-1074. - * We take as input numbers of the form w x 10^q where w < 2^64. - * We have that w * 10^-343 < 2^(64-344) 5^-343 < 2^-1076. - * However, we have that - * (2^64-1) * 10^-342 = (2^64-1) * 2^-342 * 5^-342 > 2^-1074. - * Thus it is possible for a number of the form w * 10^-342 where - * w is a 64-bit value to be a non-zero floating-point number. - ********* - * Any number of form w * 10^309 where w>= 1 is going to be - * infinite in binary64 so we never need to worry about powers - * of 5 greater than 308. - */ -constexpr int smallest_power = -342; -constexpr int largest_power = 308; - -/** - * Represents a 128-bit value. - * low: least significant 64 bits. 
- * high: most significant 64 bits. - */ -struct value128 { - uint64_t low; - uint64_t high; -}; - - -// Precomputed powers of ten from 10^0 to 10^22. These -// can be represented exactly using the double type. -extern SIMDJSON_DLLIMPORTEXPORT const double power_of_ten[]; - - -/** - * When mapping numbers from decimal to binary, - * we go from w * 10^q to m * 2^p but we have - * 10^q = 5^q * 2^q, so effectively - * we are trying to match - * w * 2^q * 5^q to m * 2^p. Thus the powers of two - * are not a concern since they can be represented - * exactly using the binary notation, only the powers of five - * affect the binary significand. - */ - - -// The truncated powers of five from 5^-342 all the way to 5^308 -// The mantissa is truncated to 128 bits, and -// never rounded up. Uses about 10KB. -extern SIMDJSON_DLLIMPORTEXPORT const uint64_t power_of_five_128[]; -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_NUMBERPARSING_TABLES_H -/* end file include/simdjson/internal/numberparsing_tables.h */ -/* begin file include/simdjson/internal/simdprune_tables.h */ -#ifndef SIMDJSON_INTERNAL_SIMDPRUNE_TABLES_H -#define SIMDJSON_INTERNAL_SIMDPRUNE_TABLES_H - -#include - -namespace simdjson { // table modified and copied from -namespace internal { // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetTable - -extern SIMDJSON_DLLIMPORTEXPORT const unsigned char BitsSetTable256mul2[256]; - -extern SIMDJSON_DLLIMPORTEXPORT const uint8_t pshufb_combine_table[272]; - -// 256 * 8 bytes = 2kB, easily fits in cache. -extern SIMDJSON_DLLIMPORTEXPORT const uint64_t thintable_epi8[256]; - -} // namespace internal -} // namespace simdjson - -#endif // SIMDJSON_INTERNAL_SIMDPRUNE_TABLES_H -/* end file include/simdjson/internal/simdprune_tables.h */ - -#endif // SIMDJSON_IMPLEMENTATION_BASE_H -/* end file include/simdjson/implementation-base.h */ - -// -// First, figure out which implementations can be run. 
Doing it here makes it so -// we don't have to worry about the order -// in which we include them. -// - -#ifndef SIMDJSON_IMPLEMENTATION_ARM64 -#define SIMDJSON_IMPLEMENTATION_ARM64 (SIMDJSON_IS_ARM64) -#endif -#define SIMDJSON_CAN_ALWAYS_RUN_ARM64 \ - SIMDJSON_IMPLEMENTATION_ARM64 &&SIMDJSON_IS_ARM64 - -// Default Haswell to on if this is x86-64. Even if we're not compiled for it, -// it could be selected -// at runtime. -#ifndef SIMDJSON_IMPLEMENTATION_HASWELL -#define SIMDJSON_IMPLEMENTATION_HASWELL (SIMDJSON_IS_X86_64) -#endif -#ifdef _MSC_VER -// To see why (__BMI__) && (__PCLMUL__) && (__LZCNT__) are not part of this -// next line, see -// https://github.com/simdjson/simdjson/issues/1247 -#define SIMDJSON_CAN_ALWAYS_RUN_HASWELL \ - ((SIMDJSON_IMPLEMENTATION_HASWELL) && (SIMDJSON_IS_X86_64) && (__AVX2__)) -#else -#define SIMDJSON_CAN_ALWAYS_RUN_HASWELL \ - ((SIMDJSON_IMPLEMENTATION_HASWELL) && (SIMDJSON_IS_X86_64) && \ - (__AVX2__) && (__BMI__) && (__PCLMUL__) && (__LZCNT__)) -#endif - -// Default Westmere to on if this is x86-64, unless we'll always select Haswell. -#ifndef SIMDJSON_IMPLEMENTATION_WESTMERE -#define SIMDJSON_IMPLEMENTATION_WESTMERE \ - (SIMDJSON_IS_X86_64 && !SIMDJSON_REQUIRES_HASWELL) -#endif -#define SIMDJSON_CAN_ALWAYS_RUN_WESTMERE \ - (SIMDJSON_IMPLEMENTATION_WESTMERE && SIMDJSON_IS_X86_64 && __SSE4_2__ && \ - __PCLMUL__) - -#ifndef SIMDJSON_IMPLEMENTATION_PPC64 -#define SIMDJSON_IMPLEMENTATION_PPC64 (SIMDJSON_IS_PPC64) -#endif -#define SIMDJSON_CAN_ALWAYS_RUN_PPC64 \ - SIMDJSON_IMPLEMENTATION_PPC64 &&SIMDJSON_IS_PPC64 - -// Default Fallback to on unless a builtin implementation has already been -// selected. 
-#ifndef SIMDJSON_IMPLEMENTATION_FALLBACK -#define SIMDJSON_IMPLEMENTATION_FALLBACK \ - 1 // (!SIMDJSON_CAN_ALWAYS_RUN_ARM64 && !SIMDJSON_CAN_ALWAYS_RUN_HASWELL && - // !SIMDJSON_CAN_ALWAYS_RUN_WESTMERE && !SIMDJSON_CAN_ALWAYS_RUN_PPC64) -#endif -#define SIMDJSON_CAN_ALWAYS_RUN_FALLBACK SIMDJSON_IMPLEMENTATION_FALLBACK - -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_UNDESIRED_WARNINGS - -// Implementations -/* begin file include/simdjson/arm64.h */ -#ifndef SIMDJSON_ARM64_H -#define SIMDJSON_ARM64_H - - -#if SIMDJSON_IMPLEMENTATION_ARM64 - -namespace simdjson { -/** - * Implementation for NEON (ARMv8). - */ -namespace arm64 {} // namespace arm64 -} // namespace simdjson - -/* begin file include/simdjson/arm64/implementation.h */ -#ifndef SIMDJSON_ARM64_IMPLEMENTATION_H -#define SIMDJSON_ARM64_IMPLEMENTATION_H - - -namespace simdjson { -namespace arm64 { - -namespace { -using namespace simdjson; -using namespace simdjson::dom; -} - -class implementation final : public simdjson::implementation { - public: - simdjson_really_inline implementation() - : simdjson::implementation( - "arm64", "ARM NEON", internal::instruction_set::NEON) {} - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final; - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final; - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final; -}; - -} // namespace arm64 -} // namespace simdjson - -#endif // SIMDJSON_ARM64_IMPLEMENTATION_H -/* end file include/simdjson/arm64/implementation.h */ - -/* begin file include/simdjson/arm64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "arm64" -// #define SIMDJSON_IMPLEMENTATION arm64 -/* end file include/simdjson/arm64/begin.h */ - -// Declarations -/* begin file include/simdjson/generic/dom_parser_implementation.h */ - -namespace simdjson { 
-namespace arm64 { - -// expectation: sizeof(open_container) = 64/8. -struct open_container { - uint32_t tape_index; // where, on the tape, does the scope ([,{) begins - uint32_t count; // how many elements in the scope -}; // struct open_container - -static_assert(sizeof(open_container) == 64 / 8, - "Open container must be 64 bits"); - -class dom_parser_implementation final - : public internal::dom_parser_implementation { - public: - /** Tape location of each open { or [ */ - std::unique_ptr open_containers{}; - /** Whether each open container is a [ or { */ - std::unique_ptr is_array{}; - /** Buffer passed to stage 1 */ - const uint8_t *buf{}; - /** Length passed to stage 1 */ - size_t len{0}; - /** Document passed to stage 2 */ - dom::document *doc{}; - - inline dom_parser_implementation() noexcept; - inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - dom_parser_implementation(const dom_parser_implementation &) = delete; - dom_parser_implementation &operator=(const dom_parser_implementation &) = - delete; - - simdjson_warn_unused error_code parse(const uint8_t *buf, - size_t len, - dom::document &doc) noexcept final; - simdjson_warn_unused error_code stage1(const uint8_t *buf, - size_t len, - stage1_mode partial) noexcept final; - simdjson_warn_unused error_code stage2(dom::document &doc) noexcept final; - simdjson_warn_unused error_code - stage2_next(dom::document &doc) noexcept final; - inline simdjson_warn_unused error_code - set_capacity(size_t capacity) noexcept final; - inline simdjson_warn_unused error_code - set_max_depth(size_t max_depth) noexcept final; - - private: - simdjson_really_inline simdjson_warn_unused error_code - set_capacity_stage1(size_t capacity); -}; - -} // namespace arm64 -} // namespace simdjson - -namespace simdjson { -namespace arm64 { - -inline dom_parser_implementation::dom_parser_implementation() noexcept 
= - default; -inline dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -inline dom_parser_implementation &dom_parser_implementation::operator=( - dom_parser_implementation &&other) noexcept = default; - -// Leaving these here so they can be inlined if so desired -inline simdjson_warn_unused error_code -dom_parser_implementation::set_capacity(size_t capacity) noexcept { - if (capacity > SIMDJSON_MAXSIZE_BYTES) { - return CAPACITY; - } - // Stage 1 index output - size_t max_structures = SIMDJSON_ROUNDUP_N(capacity, 64) + 2 + 7; - structural_indexes.reset(new (std::nothrow) uint32_t[max_structures]); - if (!structural_indexes) { - _capacity = 0; - return MEMALLOC; - } - structural_indexes[0] = 0; - n_structural_indexes = 0; - - _capacity = capacity; - return SUCCESS; -} - -inline simdjson_warn_unused error_code -dom_parser_implementation::set_max_depth(size_t max_depth) noexcept { - // Stage 2 stacks - open_containers.reset(new (std::nothrow) open_container[max_depth]); - is_array.reset(new (std::nothrow) bool[max_depth]); - if (!is_array || !open_containers) { - _max_depth = 0; - return MEMALLOC; - } - - _max_depth = max_depth; - return SUCCESS; -} - -} // namespace arm64 -} // namespace simdjson -/* end file include/simdjson/generic/dom_parser_implementation.h */ -/* begin file include/simdjson/arm64/intrinsics.h */ -#ifndef SIMDJSON_ARM64_INTRINSICS_H -#define SIMDJSON_ARM64_INTRINSICS_H - -// This should be the correct header whether -// you use visual studio or other compilers. -#include - -#endif // SIMDJSON_ARM64_INTRINSICS_H -/* end file include/simdjson/arm64/intrinsics.h */ -/* begin file include/simdjson/arm64/bitmanipulation.h */ -#ifndef SIMDJSON_ARM64_BITMANIPULATION_H -#define SIMDJSON_ARM64_BITMANIPULATION_H - -namespace simdjson { -namespace arm64 { -namespace { - -// We sometimes call trailing_zero on inputs that are zero, -// but the algorithms do not end up using the returned value. 
-// Sadly, sanitizers are not smart enough to figure it out. -SIMDJSON_NO_SANITIZE_UNDEFINED -simdjson_really_inline int trailing_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long ret; - // Search the mask data from least significant bit (LSB) - // to the most significant bit (MSB) for a set bit (1). - _BitScanForward64(&ret, input_num); - return (int)ret; -#else // SIMDJSON_REGULAR_VISUAL_STUDIO - return __builtin_ctzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline uint64_t clear_lowest_bit(uint64_t input_num) { - return input_num & (input_num - 1); -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int leading_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long leading_zero = 0; - // Search the mask data from most significant bit (MSB) - // to least significant bit (LSB) for a set bit (1). - if (_BitScanReverse64(&leading_zero, input_num)) - return (int)(63 - leading_zero); - else - return 64; -#else - return __builtin_clzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int count_ones(uint64_t input_num) { - return vaddv_u8(vcnt_u8(vcreate_u8(input_num))); -} - - -#if defined(__GNUC__) // catches clang and gcc - /** - * ARM has a fast 64-bit "bit reversal function" that is handy. However, - * it is not generally available as an intrinsic function under Visual - * Studio (though this might be changing). Even under clang/gcc, we - * apparently need to invoke inline assembly. - */ -/* - * We use SIMDJSON_PREFER_REVERSE_BITS as a hint that algorithms that - * work well with bit reversal may use it. 
- */ -#define SIMDJSON_PREFER_REVERSE_BITS 1 - -/* reverse the bits */ -simdjson_really_inline uint64_t reverse_bits(uint64_t input_num) { - uint64_t rev_bits; - __asm("rbit %0, %1" : "=r"(rev_bits) : "r"(input_num)); - return rev_bits; -} - -/** - * Flips bit at index 63 - lz. Thus if you have 'leading_zeroes' leading zeroes, - * then this will set to zero the leading bit. It is possible for leading_zeroes - *to be - * greating or equal to 63 in which case we trigger undefined behavior, but the - *output - * of such undefined behavior is never used. - **/ -SIMDJSON_NO_SANITIZE_UNDEFINED -simdjson_really_inline uint64_t zero_leading_bit(uint64_t rev_bits, - int leading_zeroes) { - return rev_bits ^ (uint64_t(0x8000000000000000) >> leading_zeroes); -} - -#endif - -simdjson_really_inline bool add_overflow(uint64_t value1, - uint64_t value2, - uint64_t *result) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - *result = value1 + value2; - return *result < value1; -#else - return __builtin_uaddll_overflow( - value1, value2, reinterpret_cast(result)); -#endif -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -#endif // SIMDJSON_ARM64_BITMANIPULATION_H -/* end file include/simdjson/arm64/bitmanipulation.h */ -/* begin file include/simdjson/arm64/bitmask.h */ -#ifndef SIMDJSON_ARM64_BITMASK_H -#define SIMDJSON_ARM64_BITMASK_H - -namespace simdjson { -namespace arm64 { -namespace { - -// -// Perform a "cumulative bitwise xor," flipping bits each time a 1 is -// encountered. -// -// For example, prefix_xor(00100100) == 00011100 -// -simdjson_really_inline uint64_t prefix_xor(uint64_t bitmask) { - ///////////// - // We could do this with PMULL, but it is apparently slow. 
- // - //#ifdef __ARM_FEATURE_CRYPTO // some ARM processors lack this extension - // return vmull_p64(-1ULL, bitmask); - //#else - // Analysis by @sebpop: - // When diffing the assembly for src/stage1_find_marks.cpp I see that the - // eors are all spread out - // in between other vector code, so effectively the extra cycles of the - // sequence do not matter - // because the GPR units are idle otherwise and the critical path is on the - // FP side. - // Also the PMULL requires two extra fmovs: GPR->FP (3 cycles in N1, 5 - // cycles in A72 ) - // and FP->GPR (2 cycles on N1 and 5 cycles on A72.) - /////////// - bitmask ^= bitmask << 1; - bitmask ^= bitmask << 2; - bitmask ^= bitmask << 4; - bitmask ^= bitmask << 8; - bitmask ^= bitmask << 16; - bitmask ^= bitmask << 32; - return bitmask; -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -#endif -/* end file include/simdjson/arm64/bitmask.h */ -/* begin file include/simdjson/arm64/simd.h */ -#ifndef SIMDJSON_ARM64_SIMD_H -#define SIMDJSON_ARM64_SIMD_H - -#include - - -namespace simdjson { -namespace arm64 { -namespace { -namespace simd { - -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO -namespace { -// Start of private section with Visual Studio workaround - - -/** - * make_uint8x16_t initializes a SIMD register (uint8x16_t). - * This is needed because, incredibly, the syntax uint8x16_t x = {1,2,3...} - * is not recognized under Visual Studio! This is a workaround. - * Using a std::initializer_list as a parameter resulted in - * inefficient code. With the current approach, if the parameters are - * compile-time constants, - * GNU GCC compiles it to ldr, the same as uint8x16_t x = {1,2,3...}. - * You should not use this function except for compile-time constants: - * it is not efficient. 
- */ -simdjson_really_inline uint8x16_t make_uint8x16_t(uint8_t x1, - uint8_t x2, - uint8_t x3, - uint8_t x4, - uint8_t x5, - uint8_t x6, - uint8_t x7, - uint8_t x8, - uint8_t x9, - uint8_t x10, - uint8_t x11, - uint8_t x12, - uint8_t x13, - uint8_t x14, - uint8_t x15, - uint8_t x16) { - // Doing a load like so end ups generating worse code. - // uint8_t array[16] = {x1, x2, x3, x4, x5, x6, x7, x8, - // x9, x10,x11,x12,x13,x14,x15,x16}; - // return vld1q_u8(array); - uint8x16_t x{}; - // incredibly, Visual Studio does not allow x[0] = x1 - x = vsetq_lane_u8(x1, x, 0); - x = vsetq_lane_u8(x2, x, 1); - x = vsetq_lane_u8(x3, x, 2); - x = vsetq_lane_u8(x4, x, 3); - x = vsetq_lane_u8(x5, x, 4); - x = vsetq_lane_u8(x6, x, 5); - x = vsetq_lane_u8(x7, x, 6); - x = vsetq_lane_u8(x8, x, 7); - x = vsetq_lane_u8(x9, x, 8); - x = vsetq_lane_u8(x10, x, 9); - x = vsetq_lane_u8(x11, x, 10); - x = vsetq_lane_u8(x12, x, 11); - x = vsetq_lane_u8(x13, x, 12); - x = vsetq_lane_u8(x14, x, 13); - x = vsetq_lane_u8(x15, x, 14); - x = vsetq_lane_u8(x16, x, 15); - return x; -} - -simdjson_really_inline uint8x8_t make_uint8x8_t(uint8_t x1, - uint8_t x2, - uint8_t x3, - uint8_t x4, - uint8_t x5, - uint8_t x6, - uint8_t x7, - uint8_t x8) { - uint8x8_t x{}; - x = vset_lane_u8(x1, x, 0); - x = vset_lane_u8(x2, x, 1); - x = vset_lane_u8(x3, x, 2); - x = vset_lane_u8(x4, x, 3); - x = vset_lane_u8(x5, x, 4); - x = vset_lane_u8(x6, x, 5); - x = vset_lane_u8(x7, x, 6); - x = vset_lane_u8(x8, x, 7); - return x; -} - -// We have to do the same work for make_int8x16_t -simdjson_really_inline int8x16_t make_int8x16_t(int8_t x1, - int8_t x2, - int8_t x3, - int8_t x4, - int8_t x5, - int8_t x6, - int8_t x7, - int8_t x8, - int8_t x9, - int8_t x10, - int8_t x11, - int8_t x12, - int8_t x13, - int8_t x14, - int8_t x15, - int8_t x16) { - // Doing a load like so end ups generating worse code. 
- // int8_t array[16] = {x1, x2, x3, x4, x5, x6, x7, x8, - // x9, x10,x11,x12,x13,x14,x15,x16}; - // return vld1q_s8(array); - int8x16_t x{}; - // incredibly, Visual Studio does not allow x[0] = x1 - x = vsetq_lane_s8(x1, x, 0); - x = vsetq_lane_s8(x2, x, 1); - x = vsetq_lane_s8(x3, x, 2); - x = vsetq_lane_s8(x4, x, 3); - x = vsetq_lane_s8(x5, x, 4); - x = vsetq_lane_s8(x6, x, 5); - x = vsetq_lane_s8(x7, x, 6); - x = vsetq_lane_s8(x8, x, 7); - x = vsetq_lane_s8(x9, x, 8); - x = vsetq_lane_s8(x10, x, 9); - x = vsetq_lane_s8(x11, x, 10); - x = vsetq_lane_s8(x12, x, 11); - x = vsetq_lane_s8(x13, x, 12); - x = vsetq_lane_s8(x14, x, 13); - x = vsetq_lane_s8(x15, x, 14); - x = vsetq_lane_s8(x16, x, 15); - return x; -} - -// End of private section with Visual Studio workaround -} // namespace -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO - - -template -struct simd8; - -// -// Base class of simd8 and simd8, both of which use uint8x16_t -// internally. -// -template > -struct base_u8 { - uint8x16_t value; - static const int SIZE = sizeof(value); - - // Conversion from/to SIMD register - simdjson_really_inline base_u8(const uint8x16_t _value) : value(_value) {} - simdjson_really_inline operator const uint8x16_t &() const { - return this->value; - } - simdjson_really_inline operator uint8x16_t &() { return this->value; } - - // Bit operations - simdjson_really_inline simd8 operator|(const simd8 other) const { - return vorrq_u8(*this, other); - } - simdjson_really_inline simd8 operator&(const simd8 other) const { - return vandq_u8(*this, other); - } - simdjson_really_inline simd8 operator^(const simd8 other) const { - return veorq_u8(*this, other); - } - simdjson_really_inline simd8 bit_andnot(const simd8 other) const { - return vbicq_u8(*this, other); - } - simdjson_really_inline simd8 operator~() const { return *this ^ 0xFFu; } - simdjson_really_inline simd8 &operator|=(const simd8 other) { - auto this_cast = static_cast *>(this); - *this_cast = *this_cast | other; - return 
*this_cast; - } - simdjson_really_inline simd8 &operator&=(const simd8 other) { - auto this_cast = static_cast *>(this); - *this_cast = *this_cast & other; - return *this_cast; - } - simdjson_really_inline simd8 &operator^=(const simd8 other) { - auto this_cast = static_cast *>(this); - *this_cast = *this_cast ^ other; - return *this_cast; - } - - friend simdjson_really_inline Mask operator==(const simd8 lhs, - const simd8 rhs) { - return vceqq_u8(lhs, rhs); - } - - template - simdjson_really_inline simd8 prev(const simd8 prev_chunk) const { - return vextq_u8(prev_chunk, *this, 16 - N); - } -}; - -// SIMD byte mask type (returned by things like eq and gt) -template <> -struct simd8 : base_u8 { - typedef uint16_t bitmask_t; - typedef uint32_t bitmask2_t; - - static simdjson_really_inline simd8 splat(bool _value) { - return vmovq_n_u8(uint8_t(-(!!_value))); - } - - simdjson_really_inline simd8(const uint8x16_t _value) - : base_u8(_value) {} - // False constructor - simdjson_really_inline simd8() : simd8(vdupq_n_u8(0)) {} - // Splat constructor - simdjson_really_inline simd8(bool _value) : simd8(splat(_value)) {} - - // We return uint32_t instead of uint16_t because that seems to be more - // efficient for most - // purposes (cutting it down to uint16_t costs performance in some - // compilers). 
- simdjson_really_inline uint32_t to_bitmask() const { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - const uint8x16_t bit_mask = make_uint8x16_t(0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80, - 0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80); -#else - const uint8x16_t bit_mask = {0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80, - 0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80}; -#endif - auto minput = *this & bit_mask; - uint8x16_t tmp = vpaddq_u8(minput, minput); - tmp = vpaddq_u8(tmp, tmp); - tmp = vpaddq_u8(tmp, tmp); - return vgetq_lane_u16(vreinterpretq_u16_u8(tmp), 0); - } - simdjson_really_inline bool any() const { return vmaxvq_u8(*this) != 0; } -}; - -// Unsigned bytes -template <> -struct simd8 : base_u8 { - static simdjson_really_inline uint8x16_t splat(uint8_t _value) { - return vmovq_n_u8(_value); - } - static simdjson_really_inline uint8x16_t zero() { return vdupq_n_u8(0); } - static simdjson_really_inline uint8x16_t load(const uint8_t *values) { - return vld1q_u8(values); - } - - simdjson_really_inline simd8(const uint8x16_t _value) - : base_u8(_value) {} - // Zero constructor - simdjson_really_inline simd8() : simd8(zero()) {} - // Array constructor - simdjson_really_inline simd8(const uint8_t values[16]) - : simd8(load(values)) {} - // Splat constructor - simdjson_really_inline simd8(uint8_t _value) : simd8(splat(_value)) {} -// Member-by-member initialization -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - simdjson_really_inline simd8(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) - : simd8(make_uint8x16_t(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15)) {} -#else - simdjson_really_inline simd8(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - 
uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) - : simd8(uint8x16_t{v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15}) {} -#endif - - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Store to array - simdjson_really_inline void store(uint8_t dst[16]) const { - return vst1q_u8(dst, *this); - } - - // Saturated math - simdjson_really_inline simd8 saturating_add( - const simd8 other) const { - return vqaddq_u8(*this, other); - } - simdjson_really_inline simd8 saturating_sub( - const simd8 other) const { - return vqsubq_u8(*this, other); - } - - // Addition/subtraction are the same for signed and unsigned - simdjson_really_inline simd8 operator+( - const simd8 other) const { - return vaddq_u8(*this, other); - } - simdjson_really_inline simd8 operator-( - const simd8 other) const { - return vsubq_u8(*this, other); - } - simdjson_really_inline simd8 &operator+=( - const simd8 other) { - *this = *this + other; - return *this; - } - simdjson_really_inline simd8 &operator-=( - const simd8 other) { - *this = *this - other; - return *this; - } - - // Order-specific operations - simdjson_really_inline uint8_t max_val() const { return vmaxvq_u8(*this); } - simdjson_really_inline uint8_t min_val() const { return vminvq_u8(*this); } - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return vmaxq_u8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return 
vminq_u8(*this, other); - } - simdjson_really_inline simd8 operator<=( - const simd8 other) const { - return vcleq_u8(*this, other); - } - simdjson_really_inline simd8 operator>=( - const simd8 other) const { - return vcgeq_u8(*this, other); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return vcltq_u8(*this, other); - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return vcgtq_u8(*this, other); - } - // Same as >, but instead of guaranteeing all 1's == true, false = 0 and - // true = nonzero. For ARM, returns all 1's. - simdjson_really_inline simd8 gt_bits( - const simd8 other) const { - return simd8(*this > other); - } - // Same as <, but instead of guaranteeing all 1's == true, false = 0 and - // true = nonzero. For ARM, returns all 1's. - simdjson_really_inline simd8 lt_bits( - const simd8 other) const { - return simd8(*this < other); - } - - // Bit-specific operations - simdjson_really_inline simd8 any_bits_set(simd8 bits) const { - return vtstq_u8(*this, bits); - } - simdjson_really_inline bool any_bits_set_anywhere() const { - return this->max_val() != 0; - } - simdjson_really_inline bool any_bits_set_anywhere( - simd8 bits) const { - return (*this & bits).any_bits_set_anywhere(); - } - template - simdjson_really_inline simd8 shr() const { - return vshrq_n_u8(*this, N); - } - template - simdjson_really_inline simd8 shl() const { - return vshlq_n_u8(*this, N); - } - - // Perform a lookup assuming the value is between 0 and 16 (undefined - // behavior for out of range values) - template - simdjson_really_inline simd8 lookup_16(simd8 lookup_table) const { - return lookup_table.apply_lookup_16_to(*this); - } - - - // Copies to 'output" all bytes corresponding to a 0 in the mask - // (interpreted as a bitset). - // Passing a 0 value for mask would be equivalent to writing out every byte - // to output. 
- // Only the first 16 - count_ones(mask) bytes of the result are significant - // but 16 bytes - // get written. - // Design consideration: it seems like a function with the - // signature simd8 compress(uint16_t mask) would be - // sensible, but the AVX ISA makes this kind of approach difficult. - template - simdjson_really_inline void compress(uint16_t mask, L *output) const { - using internal::thintable_epi8; - using internal::BitsSetTable256mul2; - using internal::pshufb_combine_table; - // this particular implementation was inspired by work done by - // @animetosho - // we do it in two steps, first 8 bytes and then second 8 bytes - uint8_t mask1 = uint8_t(mask); // least significant 8 bits - uint8_t mask2 = uint8_t(mask >> 8); // most significant 8 bits - // next line just loads the 64-bit values thintable_epi8[mask1] and - // thintable_epi8[mask2] into a 128-bit register, using only - // two instructions on most compilers. - uint64x2_t shufmask64 = {thintable_epi8[mask1], thintable_epi8[mask2]}; - uint8x16_t shufmask = vreinterpretq_u8_u64(shufmask64); -// we increment by 0x08 the second half of the mask -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - uint8x16_t inc = make_uint8x16_t(0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08); -#else - uint8x16_t inc = {0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08, - 0x08}; -#endif - shufmask = vaddq_u8(shufmask, inc); - // this is the version "nearly pruned" - uint8x16_t pruned = vqtbl1q_u8(*this, shufmask); - // we still need to put the two halves together. - // we compute the popcount of the first half: - int pop1 = BitsSetTable256mul2[mask1]; - // then load the corresponding mask, what it does is to write - // only the first pop1 bytes from the first 8 bytes, and then - // it fills in with the bytes from the second 8 bytes + some filling - // at the end. 
- uint8x16_t compactmask = vld1q_u8( - reinterpret_cast(pshufb_combine_table + pop1 * 8)); - uint8x16_t answer = vqtbl1q_u8(pruned, compactmask); - vst1q_u8(reinterpret_cast(output), answer); - } - - // Copies all bytes corresponding to a 0 in the low half of the mask - // (interpreted as a - // bitset) to output1, then those corresponding to a 0 in the high half to - // output2. - template - simdjson_really_inline void compress_halves(uint16_t mask, - L *output1, - L *output2) const { - using internal::thintable_epi8; - uint8_t mask1 = uint8_t(mask); // least significant 8 bits - uint8_t mask2 = uint8_t(mask >> 8); // most significant 8 bits - uint8x8_t compactmask1 = vcreate_u8(thintable_epi8[mask1]); - uint8x8_t compactmask2 = vcreate_u8(thintable_epi8[mask2]); -// we increment by 0x08 the second half of the mask -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - uint8x8_t inc = - make_uint8x8_t(0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08); -#else - uint8x8_t inc = {0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08}; -#endif - compactmask2 = vadd_u8(compactmask2, inc); - // store each result (with the second store possibly overlapping the - // first) - vst1_u8((uint8_t *)output1, vqtbl1_u8(*this, compactmask1)); - vst1_u8((uint8_t *)output2, vqtbl1_u8(*this, compactmask2)); - } - - template - simdjson_really_inline simd8 lookup_16(L replace0, - L replace1, - L replace2, - L replace3, - L replace4, - L replace5, - L replace6, - L replace7, - L replace8, - L replace9, - L replace10, - L replace11, - L replace12, - L replace13, - L replace14, - L replace15) const { - return lookup_16(simd8::repeat_16(replace0, - replace1, - replace2, - replace3, - replace4, - replace5, - replace6, - replace7, - replace8, - replace9, - replace10, - replace11, - replace12, - replace13, - replace14, - replace15)); - } - - template - simdjson_really_inline simd8 apply_lookup_16_to( - const simd8 original) { - return vqtbl1q_u8(*this, simd8(original)); - } -}; - -// Signed bytes -template <> 
-struct simd8 { - int8x16_t value; - - static simdjson_really_inline simd8 splat(int8_t _value) { - return vmovq_n_s8(_value); - } - static simdjson_really_inline simd8 zero() { return vdupq_n_s8(0); } - static simdjson_really_inline simd8 load(const int8_t values[16]) { - return vld1q_s8(values); - } - - // Conversion from/to SIMD register - simdjson_really_inline simd8(const int8x16_t _value) : value{_value} {} - simdjson_really_inline operator const int8x16_t &() const { - return this->value; - } - simdjson_really_inline operator int8x16_t &() { return this->value; } - - // Zero constructor - simdjson_really_inline simd8() : simd8(zero()) {} - // Splat constructor - simdjson_really_inline simd8(int8_t _value) : simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const int8_t *values) : simd8(load(values)) {} -// Member-by-member initialization -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - simdjson_really_inline simd8(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) - : simd8(make_int8x16_t(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15)) {} -#else - simdjson_really_inline simd8(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) - : simd8(int8x16_t{v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15}) {} -#endif - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t 
v14, - int8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Store to array - simdjson_really_inline void store(int8_t dst[16]) const { - return vst1q_s8(dst, *this); - } - -// Explicit conversion to/from unsigned -// -// Under Visual Studio/ARM64 uint8x16_t and int8x16_t are apparently the same -// type. -// In theory, we could check this occurrence with std::same_as and -// std::enabled_if but it is C++14 -// and relatively ugly and hard to read. -#ifndef SIMDJSON_REGULAR_VISUAL_STUDIO - simdjson_really_inline explicit simd8(const uint8x16_t other) - : simd8(vreinterpretq_s8_u8(other)) {} -#endif - simdjson_really_inline explicit operator simd8() const { - return vreinterpretq_u8_s8(this->value); - } - - // Math - simdjson_really_inline simd8 operator+( - const simd8 other) const { - return vaddq_s8(*this, other); - } - simdjson_really_inline simd8 operator-( - const simd8 other) const { - return vsubq_s8(*this, other); - } - simdjson_really_inline simd8 &operator+=( - const simd8 other) { - *this = *this + other; - return *this; - } - simdjson_really_inline simd8 &operator-=( - const simd8 other) { - *this = *this - other; - return *this; - } - - // Order-sensitive comparisons - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return vmaxq_s8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return vminq_s8(*this, other); - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return vcgtq_s8(*this, other); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return vcltq_s8(*this, other); - } - simdjson_really_inline simd8 operator==( - const simd8 other) const { - return vceqq_s8(*this, other); - } - - template - simdjson_really_inline simd8 prev( - const simd8 prev_chunk) const { - return vextq_s8(prev_chunk, *this, 16 - N); - } - - // Perform a lookup assuming no value 
is larger than 16 - template - simdjson_really_inline simd8 lookup_16(simd8 lookup_table) const { - return lookup_table.apply_lookup_16_to(*this); - } - template - simdjson_really_inline simd8 lookup_16(L replace0, - L replace1, - L replace2, - L replace3, - L replace4, - L replace5, - L replace6, - L replace7, - L replace8, - L replace9, - L replace10, - L replace11, - L replace12, - L replace13, - L replace14, - L replace15) const { - return lookup_16(simd8::repeat_16(replace0, - replace1, - replace2, - replace3, - replace4, - replace5, - replace6, - replace7, - replace8, - replace9, - replace10, - replace11, - replace12, - replace13, - replace14, - replace15)); - } - - template - simdjson_really_inline simd8 apply_lookup_16_to( - const simd8 original) { - return vqtbl1q_s8(*this, simd8(original)); - } -}; - -template -struct simd8x64 { - static constexpr int NUM_CHUNKS = 64 / sizeof(simd8); - static_assert(NUM_CHUNKS == 4, - "ARM kernel should use four registers per 64-byte block."); - const simd8 chunks[NUM_CHUNKS]; - - simd8x64(const simd8x64 &o) = delete; // no copy allowed - simd8x64 &operator=(const simd8 &other) = - delete; // no assignment allowed - simd8x64() = delete; // no default constructor allowed - - simdjson_really_inline simd8x64(const simd8 chunk0, - const simd8 chunk1, - const simd8 chunk2, - const simd8 chunk3) - : chunks{chunk0, chunk1, chunk2, chunk3} {} - simdjson_really_inline simd8x64(const T ptr[64]) - : chunks{simd8::load(ptr), - simd8::load(ptr + 16), - simd8::load(ptr + 32), - simd8::load(ptr + 48)} {} - - simdjson_really_inline void store(T ptr[64]) const { - this->chunks[0].store(ptr + sizeof(simd8) * 0); - this->chunks[1].store(ptr + sizeof(simd8) * 1); - this->chunks[2].store(ptr + sizeof(simd8) * 2); - this->chunks[3].store(ptr + sizeof(simd8) * 3); - } - - simdjson_really_inline simd8 reduce_or() const { - return (this->chunks[0] | this->chunks[1]) | - (this->chunks[2] | this->chunks[3]); - } - - - simdjson_really_inline 
uint64_t compress(uint64_t mask, T *output) const { - uint64_t popcounts = - vget_lane_u64(vreinterpret_u64_u8(vcnt_u8(vcreate_u8(~mask))), 0); - // compute the prefix sum of the popcounts of each byte - uint64_t offsets = popcounts * 0x0101010101010101; - this->chunks[0].compress_halves( - uint16_t(mask), output, &output[popcounts & 0xFF]); - this->chunks[1].compress_halves(uint16_t(mask >> 16), - &output[(offsets >> 8) & 0xFF], - &output[(offsets >> 16) & 0xFF]); - this->chunks[2].compress_halves(uint16_t(mask >> 32), - &output[(offsets >> 24) & 0xFF], - &output[(offsets >> 32) & 0xFF]); - this->chunks[3].compress_halves(uint16_t(mask >> 48), - &output[(offsets >> 40) & 0xFF], - &output[(offsets >> 48) & 0xFF]); - return offsets >> 56; - } - - simdjson_really_inline uint64_t to_bitmask() const { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - const uint8x16_t bit_mask = make_uint8x16_t(0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80, - 0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80); -#else - const uint8x16_t bit_mask = {0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80, - 0x01, - 0x02, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80}; -#endif - // Add each of the elements next to each other, successively, to stuff - // each 8 byte mask into one. 
- uint8x16_t sum0 = - vpaddq_u8(this->chunks[0] & bit_mask, this->chunks[1] & bit_mask); - uint8x16_t sum1 = - vpaddq_u8(this->chunks[2] & bit_mask, this->chunks[3] & bit_mask); - sum0 = vpaddq_u8(sum0, sum1); - sum0 = vpaddq_u8(sum0, sum0); - return vgetq_lane_u64(vreinterpretq_u64_u8(sum0), 0); - } - - simdjson_really_inline uint64_t eq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] == mask, - this->chunks[1] == mask, - this->chunks[2] == mask, - this->chunks[3] == mask) - .to_bitmask(); - } - - simdjson_really_inline uint64_t lteq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] <= mask, - this->chunks[1] <= mask, - this->chunks[2] <= mask, - this->chunks[3] <= mask) - .to_bitmask(); - } -}; // struct simd8x64 - -} // namespace simd -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -#endif // SIMDJSON_ARM64_SIMD_H -/* end file include/simdjson/arm64/simd.h */ -/* begin file include/simdjson/generic/jsoncharutils.h */ - -namespace simdjson { -namespace arm64 { -namespace { -namespace jsoncharutils { - -// return non-zero if not a structural or whitespace char -// zero otherwise -simdjson_really_inline uint32_t is_not_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace_negated[c]; -} - -simdjson_really_inline uint32_t is_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace[c]; -} - -// returns a value with the high 16 bits set if not valid -// otherwise returns the conversion of the 4 hex digits at src into the bottom -// 16 bits of the 32-bit return register -// -// see -// https://lemire.me/blog/2019/04/17/parsing-short-hexadecimal-strings-efficiently/ -static inline uint32_t hex_to_u32_nocheck( - const uint8_t *src) { // strictly speaking, static inline is a C-ism - uint32_t v1 = internal::digit_to_val32[630 + src[0]]; - uint32_t v2 = internal::digit_to_val32[420 + src[1]]; - uint32_t v3 = 
internal::digit_to_val32[210 + src[2]]; - uint32_t v4 = internal::digit_to_val32[0 + src[3]]; - return v1 | v2 | v3 | v4; -} - -// given a code point cp, writes to c -// the utf-8 code, outputting the length in -// bytes, if the length is zero, the code point -// is invalid -// -// This can possibly be made faster using pdep -// and clz and table lookups, but JSON documents -// have few escaped code points, and the following -// function looks cheap. -// -// Note: we assume that surrogates are treated separately -// -simdjson_really_inline size_t codepoint_to_utf8(uint32_t cp, uint8_t *c) { - if (cp <= 0x7F) { - c[0] = uint8_t(cp); - return 1; // ascii - } - if (cp <= 0x7FF) { - c[0] = uint8_t((cp >> 6) + 192); - c[1] = uint8_t((cp & 63) + 128); - return 2; // universal plane - // Surrogates are treated elsewhere... - //} //else if (0xd800 <= cp && cp <= 0xdfff) { - // return 0; // surrogates // could put assert here - } else if (cp <= 0xFFFF) { - c[0] = uint8_t((cp >> 12) + 224); - c[1] = uint8_t(((cp >> 6) & 63) + 128); - c[2] = uint8_t((cp & 63) + 128); - return 3; - } else if (cp <= - 0x10FFFF) { // if you know you have a valid code point, this - // is not needed - c[0] = uint8_t((cp >> 18) + 240); - c[1] = uint8_t(((cp >> 12) & 63) + 128); - c[2] = uint8_t(((cp >> 6) & 63) + 128); - c[3] = uint8_t((cp & 63) + 128); - return 4; - } - // will return 0 when the code point was too large. 
- return 0; // bad r -} - -#ifdef SIMDJSON_IS_32BITS // _umul128 for x86, arm -// this is a slow emulation routine for 32-bit -// -static simdjson_really_inline uint64_t __emulu(uint32_t x, uint32_t y) { - return x * (uint64_t)y; -} -static simdjson_really_inline uint64_t _umul128(uint64_t ab, - uint64_t cd, - uint64_t *hi) { - uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd); - uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd); - uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32)); - uint64_t adbc_carry = !!(adbc < ad); - uint64_t lo = bd + (adbc << 32); - *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + - (adbc_carry << 32) + !!(lo < bd); - return lo; -} -#endif - -using internal::value128; - -simdjson_really_inline value128 full_multiplication(uint64_t value1, - uint64_t value2) { - value128 answer; -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) -#ifdef _M_ARM64 - // ARM64 has native support for 64-bit multiplications, no need to emultate - answer.high = __umulh(value1, value2); - answer.low = value1 * value2; -#else - answer.low = _umul128( - value1, value2, &answer.high); // _umul128 not available on ARM64 -#endif // _M_ARM64 -#else // defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) - __uint128_t r = (static_cast<__uint128_t>(value1)) * value2; - answer.low = uint64_t(r); - answer.high = uint64_t(r >> 64); -#endif - return answer; -} - -} // namespace jsoncharutils -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file include/simdjson/generic/jsoncharutils.h */ -/* begin file include/simdjson/generic/atomparsing.h */ -namespace simdjson { -namespace arm64 { -namespace { -/// @private -namespace atomparsing { - -// The string_to_uint32 is exclusively used to map literal strings to 32-bit -// values. 
-// We use memcpy instead of a pointer cast to avoid undefined behaviors since we -// cannot -// be certain that the character pointer will be properly aligned. -// You might think that using memcpy makes this function expensive, but you'd be -// wrong. -// All decent optimizing compilers (GCC, clang, Visual Studio) will compile -// string_to_uint32("false"); -// to the compile-time constant 1936482662. -simdjson_really_inline uint32_t string_to_uint32(const char *str) { - uint32_t val; - std::memcpy(&val, str, sizeof(uint32_t)); - return val; -} - - -// Again in str4ncmp we use a memcpy to avoid undefined behavior. The memcpy may -// appear expensive. -// Yet all decent optimizing compilers will compile memcpy to a single -// instruction, just about. -simdjson_warn_unused simdjson_really_inline uint32_t -str4ncmp(const uint8_t *src, const char *atom) { - uint32_t - srcval; // we want to avoid unaligned 32-bit loads (undefined in C/C++) - static_assert(sizeof(uint32_t) <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be larger than 4 bytes"); - std::memcpy(&srcval, src, sizeof(uint32_t)); - return srcval ^ string_to_uint32(atom); -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src) { - return (str4ncmp(src, "true") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_true_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "true"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src) { - return (str4ncmp(src + 1, "alse") | - jsoncharutils::is_not_structural_or_whitespace(src[5])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src, size_t len) { - if (len > 5) { - return is_valid_false_atom(src); - } else if (len == 5) { - return 
!str4ncmp(src + 1, "alse"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src) { - return (str4ncmp(src, "null") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_null_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "null"); - } else { - return false; - } -} - -} // namespace atomparsing -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file include/simdjson/generic/atomparsing.h */ -/* begin file include/simdjson/arm64/stringparsing.h */ -#ifndef SIMDJSON_ARM64_STRINGPARSING_H -#define SIMDJSON_ARM64_STRINGPARSING_H - - -namespace simdjson { -namespace arm64 { -namespace { - -using namespace simd; - -// Holds backslashes and quotes locations. -struct backslash_and_quote { - public: - static constexpr uint32_t BYTES_PROCESSED = 32; - simdjson_really_inline static backslash_and_quote copy_and_find( - const uint8_t *src, uint8_t *dst); - - simdjson_really_inline bool has_quote_first() { - return ((bs_bits - 1) & quote_bits) != 0; - } - simdjson_really_inline bool has_backslash() { return bs_bits != 0; } - simdjson_really_inline int quote_index() { - return trailing_zeroes(quote_bits); - } - simdjson_really_inline int backslash_index() { - return trailing_zeroes(bs_bits); - } - - uint32_t bs_bits; - uint32_t quote_bits; -}; // struct backslash_and_quote - -simdjson_really_inline backslash_and_quote -backslash_and_quote::copy_and_find(const uint8_t *src, uint8_t *dst) { - // this can read up to 31 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(SIMDJSON_PADDING >= (BYTES_PROCESSED - 1), - "backslash and quote finder must process fewer than " - "SIMDJSON_PADDING bytes"); - simd8 v0(src); - simd8 v1(src + sizeof(v0)); - v0.store(dst); - v1.store(dst + 
sizeof(v0)); - - // Getting a 64-bit bitmask is much cheaper than multiple 16-bit bitmasks on - // ARM; therefore, we - // smash them together into a 64-byte mask and get the bitmask from there. - uint64_t bs_and_quote = - simd8x64(v0 == '\\', v1 == '\\', v0 == '"', v1 == '"') - .to_bitmask(); - return { - uint32_t(bs_and_quote), // bs_bits - uint32_t(bs_and_quote >> 32) // quote_bits - }; -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -/* begin file include/simdjson/generic/stringparsing.h */ -// This file contains the common code every implementation uses -// It is intended to be included multiple times and compiled multiple times - -namespace simdjson { -namespace arm64 { -namespace { -/// @private -namespace stringparsing { - -// begin copypasta -// These chars yield themselves: " \ / -// b -> backspace, f -> formfeed, n -> newline, r -> cr, t -> horizontal tab -// u not handled in this table as it's complex -static const uint8_t escape_map[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x4. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x5c, 0, 0, 0, // 0x5. - 0, 0, 0x08, 0, 0, 0, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0x0a, 0, // 0x6. - 0, 0, 0x0d, 0, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x7. 
- - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -// handle a unicode codepoint -// write appropriate values into dest -// src will advance 6 bytes or 12 bytes -// dest will advance a variable amount (return via pointer) -// return true if the unicode codepoint was valid -// We work in little-endian then swap at write time -simdjson_warn_unused simdjson_really_inline bool handle_unicode_codepoint( - const uint8_t **src_ptr, uint8_t **dst_ptr) { - // jsoncharutils::hex_to_u32_nocheck fills high 16 bits of the return value - // with 1s if the - // conversion isn't valid; we defer the check for this to inside the - // multilingual plane check - uint32_t code_point = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - *src_ptr += 6; - // check for low surrogate for characters outside the Basic - // Multilingual Plane. - if (code_point >= 0xd800 && code_point < 0xdc00) { - if (((*src_ptr)[0] != '\\') || (*src_ptr)[1] != 'u') { - return false; - } - uint32_t code_point_2 = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - - // if the first code point is invalid we will get here, as we will go - // past - // the check for being outside the Basic Multilingual plane. If we don't - // find a \u immediately afterwards we fail out anyhow, but if we do, - // this check catches both the case of the first code point being - // invalid - // or the second code point being invalid. 
- if ((code_point | code_point_2) >> 16) { - return false; - } - - code_point = - (((code_point - 0xd800) << 10) | (code_point_2 - 0xdc00)) + 0x10000; - *src_ptr += 6; - } - size_t offset = jsoncharutils::codepoint_to_utf8(code_point, *dst_ptr); - *dst_ptr += offset; - return offset > 0; -} - -/** - * Unescape a string from src to dst, stopping at a final unescaped quote. E.g., - * if src points at 'joe"', then - * dst needs to have four free bytes. - */ -simdjson_warn_unused simdjson_really_inline uint8_t *parse_string( - const uint8_t *src, uint8_t *dst) { - while (1) { - // Copy the next n bytes, and find the backslash and quote in them. - auto bs_quote = backslash_and_quote::copy_and_find(src, dst); - // If the next thing is the end quote, copy and return - if (bs_quote.has_quote_first()) { - // we encountered quotes first. Move dst to point to quotes and exit - return dst + bs_quote.quote_index(); - } - if (bs_quote.has_backslash()) { - /* find out where the backspace is */ - auto bs_dist = bs_quote.backslash_index(); - uint8_t escape_char = src[bs_dist + 1]; - /* we encountered backslash first. Handle backslash */ - if (escape_char == 'u') { - /* move src/dst up to the start; they will be further adjusted - within the unicode codepoint handling code. */ - src += bs_dist; - dst += bs_dist; - if (!handle_unicode_codepoint(&src, &dst)) { - return nullptr; - } - } else { - /* simple 1:1 conversion. Will eat bs_dist+2 characters in input - * and - * write bs_dist+1 characters to output - * note this may reach beyond the part of the buffer we've - * actually - * seen. I think this is ok */ - uint8_t escape_result = escape_map[escape_char]; - if (escape_result == 0u) { - return nullptr; /* bogus escape value is an error */ - } - dst[bs_dist] = escape_result; - src += bs_dist + 2; - dst += bs_dist + 1; - } - } else { - /* they are the same. Since they can't co-occur, it means we - * encountered neither. 
*/ - src += backslash_and_quote::BYTES_PROCESSED; - dst += backslash_and_quote::BYTES_PROCESSED; - } - } - /* can't be reached */ - return nullptr; -} - -simdjson_unused simdjson_warn_unused simdjson_really_inline error_code -parse_string_to_buffer(const uint8_t *src, - uint8_t *¤t_string_buf_loc, - std::string_view &s) { - if (*(src++) != '"') { - return STRING_ERROR; - } - auto end = stringparsing::parse_string(src, current_string_buf_loc); - if (!end) { - return STRING_ERROR; - } - s = std::string_view(reinterpret_cast(current_string_buf_loc), - end - current_string_buf_loc); - current_string_buf_loc = end; - return SUCCESS; -} - -} // namespace stringparsing -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file include/simdjson/generic/stringparsing.h */ - -#endif // SIMDJSON_ARM64_STRINGPARSING_H -/* end file include/simdjson/arm64/stringparsing.h */ -/* begin file include/simdjson/arm64/numberparsing.h */ -#ifndef SIMDJSON_ARM64_NUMBERPARSING_H -#define SIMDJSON_ARM64_NUMBERPARSING_H - -namespace simdjson { -namespace arm64 { -namespace { - -// we don't have SSE, so let us use a scalar function -// credit: https://johnnylee-sde.github.io/Fast-numeric-string-to-int/ -static simdjson_really_inline uint32_t -parse_eight_digits_unrolled(const uint8_t *chars) { - uint64_t val; - std::memcpy(&val, chars, sizeof(uint64_t)); - val = (val & 0x0F0F0F0F0F0F0F0F) * 2561 >> 8; - val = (val & 0x00FF00FF00FF00FF) * 6553601 >> 16; - return uint32_t((val & 0x0000FFFF0000FFFF) * 42949672960001 >> 32); -} - -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson - -#define SIMDJSON_SWAR_NUMBER_PARSING 1 - -/* begin file include/simdjson/generic/numberparsing.h */ -#include - -namespace simdjson { -namespace arm64 { - -namespace ondemand { -/** - * The type of a JSON number - */ -enum class number_type { - floating_point_number = 1, /// a binary64 number - signed_integer, /// a signed integer that fits in a 64-bit word using two's - 
/// complement - unsigned_integer /// a positive integer larger or equal to 1<<63 -}; -} - -namespace { -/// @private -namespace numberparsing { - - -#ifdef JSON_TEST_NUMBERS -#define INVALID_NUMBER(SRC) (found_invalid_number((SRC)), NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) \ - (found_integer((VALUE), (SRC)), (WRITER).append_s64((VALUE))) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) \ - (found_unsigned_integer((VALUE), (SRC)), (WRITER).append_u64((VALUE))) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) \ - (found_float((VALUE), (SRC)), (WRITER).append_double((VALUE))) -#else -#define INVALID_NUMBER(SRC) (NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) (WRITER).append_s64((VALUE)) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) (WRITER).append_u64((VALUE)) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) (WRITER).append_double((VALUE)) -#endif - -namespace { -// Convert a mantissa, an exponent and a sign bit into an ieee64 double. -// The real_exponent needs to be in [0, 2046] (technically real_exponent = 2047 -// would be acceptable). -// The mantissa should be in [0,1<<53). The bit at index (1ULL << 52) while be -// zeroed. -simdjson_really_inline double to_double(uint64_t mantissa, - uint64_t real_exponent, - bool negative) { - double d; - mantissa &= ~(1ULL << 52); - mantissa |= real_exponent << 52; - mantissa |= ((static_cast(negative)) << 63); - std::memcpy(&d, &mantissa, sizeof(d)); - return d; -} -} -// Attempts to compute i * 10^(power) exactly; and if "negative" is -// true, negate the result. -// This function will only work in some cases, when it does not work, success is -// set to false. This should work *most of the time* (like 99% of the time). -// We assume that power is in the [smallest_power, -// largest_power] interval: the caller is responsible for this check. -simdjson_really_inline bool compute_float_64(int64_t power, - uint64_t i, - bool negative, - double &d) { -// we start with a fast path -// It was described in -// Clinger WD. 
How to read floating point numbers accurately. -// ACM SIGPLAN Notices. 1990 -#ifndef FLT_EVAL_METHOD -#error "FLT_EVAL_METHOD should be defined, please include cfloat." -#endif -#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) - // We cannot be certain that x/y is rounded to nearest. - if (0 <= power && power <= 22 && i <= 9007199254740991) { -#else - if (-22 <= power && power <= 22 && i <= 9007199254740991) { -#endif - // convert the integer into a double. This is lossless since - // 0 <= i <= 2^53 - 1. - d = double(i); - // - // The general idea is as follows. - // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then - // 1) Both s and p can be represented exactly as 64-bit floating-point - // values - // (binary64). - // 2) Because s and p can be represented exactly as floating-point - // values, - // then s * p - // and s / p will produce correctly rounded values. - // - if (power < 0) { - d = d / simdjson::internal::power_of_ten[-power]; - } else { - d = d * simdjson::internal::power_of_ten[power]; - } - if (negative) { - d = -d; - } - return true; - } - // When 22 < power && power < 22 + 16, we could - // hope for another, secondary fast path. It was - // described by David M. Gay in "Correctly rounded - // binary-decimal and decimal-binary conversions." (1990) - // If you need to compute i * 10^(22 + x) for x < 16, - // first compute i * 10^x, if you know that result is exact - // (e.g., when i * 10^x < 2^53), - // then you can still proceed and do (i * 10^x) * 10^22. - // Is this worth your time? - // You need 22 < power *and* power < 22 + 16 *and* (i * 10^(x-22) < 2^53) - // for this second fast path to work. - // If you you have 22 < power *and* power < 22 + 16, and then you - // optimistically compute "i * 10^(x-22)", there is still a chance that you - // have wasted your time if i * 10^(x-22) >= 2^53. It makes the use cases of - // this optimization maybe less common than we would like. 
Source: - // http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ - // also used in RapidJSON: https://rapidjson.org/strtod_8h_source.html - - // The fast path has now failed, so we are failing back on the slower path. - - // In the slow path, we need to adjust i so that it is > 1<<63 which is - // always - // possible, except if i == 0, so we handle i == 0 separately. - if (i == 0) { - d = 0.0; - return true; - } - - - // The exponent is 1024 + 63 + power - // + floor(log(5**power)/log(2)). - // The 1024 comes from the ieee64 standard. - // The 63 comes from the fact that we use a 64-bit word. - // - // Computing floor(log(5**power)/log(2)) could be - // slow. Instead we use a fast function. - // - // For power in (-400,350), we have that - // (((152170 + 65536) * power ) >> 16); - // is equal to - // floor(log(5**power)/log(2)) + power when power >= 0 - // and it is equal to - // ceil(log(5**-power)/log(2)) + power when power < 0 - // - // The 65536 is (1<<16) and corresponds to - // (65536 * power) >> 16 ---> power - // - // ((152170 * power ) >> 16) is equal to - // floor(log(5**power)/log(2)) - // - // Note that this is not magic: 152170/(1<<16) is - // approximatively equal to log(5)/log(2). - // The 1<<16 value is a power of two; we could use a - // larger power of 2 if we wanted to. - // - int64_t exponent = (((152170 + 65536) * power) >> 16) + 1024 + 63; - - - // We want the most significant bit of i to be 1. Shift if needed. - int lz = leading_zeroes(i); - i <<= lz; - - - // We are going to need to do some 64-bit arithmetic to get a precise - // product. - // We use a table lookup approach. - // It is safe because - // power >= smallest_power - // and power <= largest_power - // We recover the mantissa of the power, it has a leading 1. It is always - // rounded down. - // - // We want the most significant 64 bits of the product. We know - // this will be non-zero because the most significant bit of i is - // 1. 
- const uint32_t index = - 2 * uint32_t(power - simdjson::internal::smallest_power); - // Optimization: It may be that materializing the index as a variable might - // confuse some compilers and prevent effective complex-addressing loads. - // (Done for code clarity.) - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high component" - // corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 firstproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index]); - // Both i and power_of_five_128[index] have their most significant bit set - // to 1 which - // implies that the either the most or the second most significant bit of - // the product - // is 1. We pack values in this manner for efficiency reasons: it maximizes - // the use - // we make of the product. It also makes it easy to reason about the - // product: there - // is 0 or 1 leading zero in the product. - - // Unless the least significant 9 bits of the high (64-bit) part of the full - // product are all 1s, then we know that the most significant 55 bits are - // exact and no further work is needed. Having 55 bits is necessary because - // we need 53 bits for the mantissa but we have to have one rounding bit and - // we can waste a bit if the most significant bit of the product is zero. - if ((firstproduct.high & 0x1FF) == 0x1FF) { - // We want to compute i * 5^q, but only care about the top 55 bits at - // most. - // Consider the scenario where q>=0. Then 5^q may not fit in 64-bits. - // Doing - // the full computation is wasteful. So we do what is called a - // "truncated - // multiplication". - // We take the most significant 64-bits, and we put them in - // power_of_five_128[index]. 
Usually, that's good enough to approximate - // i * 5^q - // to the desired approximation using one multiplication. Sometimes it - // does not suffice. - // Then we store the next most significant 64 bits in - // power_of_five_128[index + 1], and - // then we get a better approximation to i * 5^q. In very rare cases, - // even that - // will not suffice, though it is seemingly very hard to find such a - // scenario. - // - // That's for when q>=0. The logic for q<0 is somewhat similar but it is - // somewhat - // more complicated. - // - // There is an extra layer of complexity in that we need more than 55 - // bits of - // accuracy in the round-to-even scenario. - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high - // component" corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 secondproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index + 1]); - firstproduct.low += secondproduct.high; - if (secondproduct.high > firstproduct.low) { - firstproduct.high++; - } - // At this point, we might need to add at most one to firstproduct, but - // this - // can only change the value of firstproduct.high if firstproduct.low is - // maximal. - if (simdjson_unlikely(firstproduct.low == 0xFFFFFFFFFFFFFFFF)) { - // This is very unlikely, but if so, we need to do much more work! - return false; - } - } - uint64_t lower = firstproduct.low; - uint64_t upper = firstproduct.high; - // The final mantissa should be 53 bits with a leading 1. - // We shift it so that it occupies 54 bits with a leading 1. - /////// - uint64_t upperbit = upper >> 63; - uint64_t mantissa = upper >> (upperbit + 9); - lz += int(1 ^ upperbit); - - // Here we have mantissa < (1<<54). 
- int64_t real_exponent = exponent - lz; - if (simdjson_unlikely(real_exponent <= 0)) { // we have a subnormal? - // Here have that real_exponent <= 0 so -real_exponent >= 0 - if (-real_exponent + 1 >= 64) { // if we have more than 64 bits below - // the minimum exponent, you have a - // zero for sure. - d = 0.0; - return true; - } - // next line is safe because -real_exponent + 1 < 0 - mantissa >>= -real_exponent + 1; - // Thankfully, we can't have both "round-to-even" and subnormals because - // "round-to-even" only occurs for powers close to 0. - mantissa += (mantissa & 1); // round up - mantissa >>= 1; - // There is a weird scenario where we don't have a subnormal but just. - // Suppose we start with 2.2250738585072013e-308, we end up - // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal - // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to - // round - // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer - // subnormal, but we can only know this after rounding. - // So we only declare a subnormal if we are smaller than the threshold. - real_exponent = (mantissa < (uint64_t(1) << 52)) ? 0 : 1; - d = to_double(mantissa, real_exponent, negative); - return true; - } - // We have to round to even. The "to even" part - // is only a problem when we are right in between two floats - // which we guard against. - // If we have lots of trailing zeros, we may fall right between two - // floating-point values. - // - // The round-to-even cases take the form of a number 2m+1 which is in - // (2^53,2^54] - // times a power of two. That is, it is right between a number with binary - // significand - // m and another number with binary significand m+1; and it must be the case - // that it cannot be represented by a float itself. - // - // We must have that w * 10 ^q == (2m+1) * 2^p for some power of two 2^p. - // Recall that 10^q = 5^q * 2^q. - // When q >= 0, we must have that (2m+1) is divible by 5^q, so 5^q <= 2^54. 
- // We have that - // 5^23 <= 2^54 and it is the last power of five to qualify, so q <= 23. - // When q<0, we have w >= (2m+1) x 5^{-q}. We must have that w<2^{64} so - // (2m+1) x 5^{-q} < 2^{64}. We have that 2m+1>2^{53}. Hence, we must have - // 2^{53} x 5^{-q} < 2^{64}. - // Hence we have 5^{-q} < 2^{11}$ or q>= -4. - // - // We require lower <= 1 and not lower == 0 because we could not prove that - // that lower == 0 is implied; but we could prove that lower <= 1 is a - // necessary and sufficient test. - if (simdjson_unlikely((lower <= 1) && (power >= -4) && (power <= 23) && - ((mantissa & 3) == 1))) { - if ((mantissa << (upperbit + 64 - 53 - 2)) == upper) { - mantissa &= ~1; // flip it so that we do not round up - } - } - - mantissa += mantissa & 1; - mantissa >>= 1; - - // Here we have mantissa < (1<<53), unless there was an overflow - if (mantissa >= (1ULL << 53)) { - ////////// - // This will happen when parsing values such as 7.2057594037927933e+16 - //////// - mantissa = (1ULL << 52); - real_exponent++; - } - mantissa &= ~(1ULL << 52); - // we have to check that real_exponent is in range, otherwise we bail out - if (simdjson_unlikely(real_exponent > 2046)) { - // We have an infinite value!!! We could actually throw an error here if - // we could. - return false; - } - d = to_double(mantissa, real_exponent, negative); - return true; -} - -// We call a fallback floating-point parser that might be slow. Note -// it will accept JSON numbers, but the JSON spec. is more restrictive so -// before you call parse_float_fallback, you need to have validated the input -// string with the JSON grammar. -// It will return an error (false) if the parsed number is infinite. -// The string parsing itself always succeeds. We know that there is at least -// one digit. -static bool parse_float_fallback(const uint8_t *ptr, double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr)); - // We do not accept infinite values. 
- - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). - return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} -static bool parse_float_fallback(const uint8_t *ptr, - const uint8_t *end_ptr, - double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr), - reinterpret_cast(end_ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). 
- return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} - -// check quickly whether the next 8 chars are made of digits -// at a glance, it looks better than Mula's -// http://0x80.pl/articles/swar-digits-validate.html -simdjson_really_inline bool is_made_of_eight_digits_fast(const uint8_t *chars) { - uint64_t val; - // this can read up to 7 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(7 <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be bigger than 7"); - std::memcpy(&val, chars, 8); - // a branchy method might be faster: - // return (( val & 0xF0F0F0F0F0F0F0F0 ) == 0x3030303030303030) - // && (( (val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0 ) == - // 0x3030303030303030); - return (((val & 0xF0F0F0F0F0F0F0F0) | - (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) == - 0x3333333333333333); -} - -template -error_code slow_float_parsing(simdjson_unused const uint8_t *src, W writer) { - double d; - if (parse_float_fallback(src, &d)) { - writer.append_double(d); - return SUCCESS; - } - return INVALID_NUMBER(src); -} - -template -SIMDJSON_NO_SANITIZE_UNDEFINED // We deliberately allow overflow here and check - // later - simdjson_really_inline bool - parse_digit(const uint8_t c, I &i) { - const uint8_t digit = static_cast(c - '0'); - if (digit > 9) { - return false; - } - // PERF NOTE: multiplication by 10 is cheaper than arbitrary integer - // multiplication - i = 10 * i + digit; // might overflow, we will handle the overflow later - return true; -} - -simdjson_really_inline error_code -parse_decimal(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - uint64_t &i, - int64_t &exponent) { - // we continue with the fiction that we have an integer. If the - // floating point number is representable as x * 10^z for some integer - // z that fits in 53 bits, then we will be able to convert back the - // the integer into a float in a lossless manner. 
- const uint8_t *const first_after_period = p; - -#ifdef SIMDJSON_SWAR_NUMBER_PARSING -#if SIMDJSON_SWAR_NUMBER_PARSING - // this helps if we have lots of decimals! - // this turns out to be frequent enough. - if (is_made_of_eight_digits_fast(p)) { - i = i * 100000000 + parse_eight_digits_unrolled(p); - p += 8; - } -#endif // SIMDJSON_SWAR_NUMBER_PARSING -#endif // #ifdef SIMDJSON_SWAR_NUMBER_PARSING - // Unrolling the first digit makes a small difference on some - // implementations (e.g. westmere) - if (parse_digit(*p, i)) { - ++p; - } - while (parse_digit(*p, i)) { - p++; - } - exponent = first_after_period - p; - // Decimal without digits (123.) is illegal - if (exponent == 0) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -simdjson_really_inline error_code -parse_exponent(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - int64_t &exponent) { - // Exp Sign: -123.456e[-]78 - bool neg_exp = ('-' == *p); - if (neg_exp || '+' == *p) { - p++; - } // Skip + as well - - // Exponent: -123.456e-[78] - auto start_exp = p; - int64_t exp_number = 0; - while (parse_digit(*p, exp_number)) { - ++p; - } - // It is possible for parse_digit to overflow. - // In particular, it could overflow to INT64_MIN, and we cannot do - - // INT64_MIN. - // Thus we *must* check for possible overflow before we negate exp_number. - - // Performance notes: it may seem like combining the two "simdjson_unlikely - // checks" below into - // a single simdjson_unlikely path would be faster. The reasoning is sound, - // but the compiler may - // not oblige and may, in fact, generate two distinct paths in any case. It - // might be - // possible to do uint64_t(p - start_exp - 1) >= 18 but it could end up - // trading off - // instructions for a simdjson_likely branch, an unconclusive gain. - - // If there were no digits, it's an error. 
- if (simdjson_unlikely(p == start_exp)) { - return INVALID_NUMBER(src); - } - // We have a valid positive exponent in exp_number at this point, except - // that - // it may have overflowed. - - // If there were more than 18 digits, we may have overflowed the integer. We - // have to do - // something!!!! - if (simdjson_unlikely(p > start_exp + 18)) { - // Skip leading zeroes: 1e000000000000000000001 is technically valid and - // doesn't overflow - while (*start_exp == '0') { - start_exp++; - } - // 19 digits could overflow int64_t and is kind of absurd anyway. We - // don't - // support exponents smaller than -999,999,999,999,999,999 and bigger - // than 999,999,999,999,999,999. - // We can truncate. - // Note that 999999999999999999 is assuredly too large. The maximal - // ieee64 value before - // infinity is ~1.8e308. The smallest subnormal is ~5e-324. So, - // actually, we could - // truncate at 324. - // Note that there is no reason to fail per se at this point in time. - // E.g., 0e999999999999999999999 is a fine number. - if (p > start_exp + 18) { - exp_number = 999999999999999999; - } - } - // At this point, we know that exp_number is a sane, positive, signed - // integer. - // It is <= 999,999,999,999,999,999. As long as 'exponent' is in - // [-8223372036854775808, 8223372036854775808], we won't overflow. Because - // 'exponent' - // is bounded in magnitude by the size of the JSON input, we are fine in - // this universe. - // To sum it up: the next line should never overflow. - exponent += (neg_exp ? -exp_number : exp_number); - return SUCCESS; -} - -simdjson_really_inline size_t significant_digits(const uint8_t *start_digits, - size_t digit_count) { - // It is possible that the integer had an overflow. - // We have to handle the case where we have 0.0000somenumber. - const uint8_t *start = start_digits; - while ((*start == '0') || (*start == '.')) { - ++start; - } - // we over-decrement by one when there is a '.' 
- return digit_count - size_t(start - start_digits); -} - -template -simdjson_really_inline error_code write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer) { - // If we frequently had to deal with long strings of digits, - // we could extend our code by using a 128-bit integer instead - // of a 64-bit integer. However, this is uncommon in practice. - // - // 9999999999999999999 < 2**64 so we can accommodate 19 digits. - // If we have a decimal separator, then digit_count - 1 is the number of - // digits, but we - // may not have a decimal separator! - if (simdjson_unlikely(digit_count > 19 && - significant_digits(start_digits, digit_count) > 19)) { - // Ok, chances are good that we had an overflow! - // this is almost never going to get called!!! - // we start anew, going slowly!!! - // This will happen in the following examples: - // 10000000000000000000000000000000000000000000e+308 - // 3.1415926535897932384626433832795028841971693993751 - // - // NOTE: This makes a *copy* of the writer and passes it to - // slow_float_parsing. This happens - // because slow_float_parsing is a non-inlined function. If we passed - // our writer reference to - // it, it would force it to be stored in memory, preventing the compiler - // from picking it apart - // and putting into registers. i.e. if we pass it as reference, it gets - // slow. - // This is what forces the skip_double, as well. - error_code error = slow_float_parsing(src, writer); - writer.skip_double(); - return error; - } - // NOTE: it's weird that the simdjson_unlikely() only wraps half the if, but - // it seems to get slower any other - // way we've tried: - // https://github.com/simdjson/simdjson/pull/990#discussion_r448497331 - // To future reader: we'd love if someone found a better way, or at least - // could explain this result! 
- if (simdjson_unlikely(exponent < simdjson::internal::smallest_power) || - (exponent > simdjson::internal::largest_power)) { - // - // Important: smallest_power is such that it leads to a zero value. - // Observe that 18446744073709551615e-343 == 0, i.e. (2**64 - 1) e -343 - // is zero - // so something x 10^-343 goes to zero, but not so with something x - // 10^-342. - static_assert(simdjson::internal::smallest_power <= -342, - "smallest_power is not small enough"); - // - if ((exponent < simdjson::internal::smallest_power) || (i == 0)) { - WRITE_DOUBLE(0, src, writer); - return SUCCESS; - } else { // (exponent > largest_power) and (i != 0) - // We have, for sure, an infinite value and simdjson refuses to - // parse infinite values. - return INVALID_NUMBER(src); - } - } - double d; - if (!compute_float_64(exponent, i, negative, d)) { - // we are almost never going to get here. - if (!parse_float_fallback(src, &d)) { - return INVALID_NUMBER(src); - } - } - WRITE_DOUBLE(d, src, writer); - return SUCCESS; -} - -// for performance analysis, it is sometimes useful to skip parsing -#ifdef SIMDJSON_SKIPNUMBERPARSING - -template -simdjson_really_inline error_code parse_number(const uint8_t *const, - W &writer) { - writer.append_s64(0); // always write zero - return SUCCESS; // always succeeds -} - -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline 
simdjson_result -parse_double_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - return ondemand::number_type::signed_integer; -} -#else - -// parse the number at src -// define JSON_TEST_NUMBERS for unit testing -// -// It is assumed that the number is followed by a structural ({,},],[) character -// or a white space character. If that is not the case (e.g., when the JSON -// document is made of a single number), then it is necessary to copy the -// content and append a space before calling this function. -// -// Our objective is accurate parsing (ULP of 0) at high speed. -template -simdjson_really_inline error_code parse_number(const uint8_t *const src, - W &writer) { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - if (digit_count == 0 || ('0' == *start_digits && digit_count > 1)) { - return INVALID_NUMBER(src); - } - - // - // Handle floats if there is a . or e (or both) - // - int64_t exponent = 0; - bool is_float = false; - if ('.' 
== *p) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_decimal(src, p, i, exponent)); - digit_count = - int(p - start_digits); // used later to guard against overflows - } - if (('e' == *p) || ('E' == *p)) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_exponent(src, p, exponent)); - } - if (is_float) { - const bool dirty_end = - jsoncharutils::is_not_structural_or_whitespace(*p); - SIMDJSON_TRY(write_float( - src, negative, i, start_digits, digit_count, exponent, writer)); - if (dirty_end) { - return INVALID_NUMBER(src); - } - return SUCCESS; - } - - // The longest negative 64-bit number is 19 digits. - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - size_t longest_digit_count = negative ? 19 : 20; - if (digit_count > longest_digit_count) { - return INVALID_NUMBER(src); - } - if (digit_count == longest_digit_count) { - if (negative) { - // Anything negative above INT64_MAX+1 is invalid - if (i > uint64_t(INT64_MAX) + 1) { - return INVALID_NUMBER(src); - } - WRITE_INTEGER(~i + 1, src, writer); - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less - // than INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit - // "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), - // the result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the - // user could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. 
- // - } else if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INVALID_NUMBER(src); - } - } - - // Write unsigned if it doesn't fit in a signed integer. - if (i > uint64_t(INT64_MAX)) { - WRITE_UNSIGNED(i, src, writer); - } else { - WRITE_INTEGER(negative ? (~i + 1) : i, src, writer); - } - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -// Inlineable functions -namespace { - -// This table can be used to characterize the final character of an integer -// string. For JSON structural character and allowable white space characters, -// we return SUCCESS. For 'e', '.' and 'E', we return INCORRECT_TYPE. Otherwise -// we return NUMBER_ERROR. -// Optimization note: we could easily reduce the size of the table by half (to -// 128) -// at the cost of an extra branch. -// Optimization note: we want the values to use at most 8 bits (not, e.g., 32 -// bits): -static_assert(error_code(uint8_t(NUMBER_ERROR)) == NUMBER_ERROR, - "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(SUCCESS)) == SUCCESS, "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(INCORRECT_TYPE)) == INCORRECT_TYPE, - "bad NUMBER_ERROR cast"); - -const uint8_t integer_string_finisher[256] = { - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, INCORRECT_TYPE, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, SUCCESS, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR}; - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. 
- if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - - -// Parse any number from 0 to 18,446,744,073,709,551,615 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. 
- size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - const uint8_t *p = src + 1; - // - // Parse the integer part. 
- // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. 
- // - // Note: we use src[1] and not src[0] because src[0] is the quote - // character in this - // instance. - if (src[1] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. 
- // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - // - // Check for minus sign - // - if (src == src_end) { - return NUMBER_ERROR; - } - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - const uint8_t *p = src + negative + 1; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... 
- // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return (*src == '-'); -} - -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if (jsoncharutils::is_structural_or_whitespace(*p)) { - return true; - } - return false; -} - -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if 
(jsoncharutils::is_structural_or_whitespace(*p)) { - int digit_count = int(p - src); - if (digit_count >= 19) { - const uint8_t *smaller_big_integer = - reinterpret_cast("9223372036854775808"); - if ((digit_count >= 20) || - (memcmp(src, smaller_big_integer, 19) >= 0)) { - return ondemand::number_type::unsigned_integer; - } - } - return ondemand::number_type::signed_integer; - } - return ondemand::number_type::floating_point_number; -} - -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src, const uint8_t *const src_end) noexcept { - if (src == src_end) { - return NUMBER_ERROR; - } - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - if (p == src_end) { - return NUMBER_ERROR; - } - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely((p != src_end) && (*p == '.'))) { - p++; - const uint8_t *start_decimal_digits = p; - if ((p == src_end) || !parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if ((p != src_end) && (*p == 'e' || *p == 'E')) { - p++; - if (p == src_end) { - return NUMBER_ERROR; - } - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while ((p != src_end) && parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if ((p != src_end) && jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, src_end, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - src += negative + 1; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. 
- // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. - overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 
0 - exp : exp; - } - - if (*p != '"') { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} -} // namespace {} -#endif // SIMDJSON_SKIPNUMBERPARSING - -} // namespace numberparsing -} // unnamed namespace -} // namespace arm64 -} // namespace simdjson -/* end file include/simdjson/generic/numberparsing.h */ - -#endif // SIMDJSON_ARM64_NUMBERPARSING_H -/* end file include/simdjson/arm64/numberparsing.h */ -/* begin file include/simdjson/arm64/end.h */ -/* end file include/simdjson/arm64/end.h */ - -#endif // SIMDJSON_IMPLEMENTATION_ARM64 - -#endif // SIMDJSON_ARM64_H -/* end file include/simdjson/arm64.h */ -/* begin file include/simdjson/fallback.h */ -#ifndef SIMDJSON_FALLBACK_H -#define SIMDJSON_FALLBACK_H - - -#if SIMDJSON_IMPLEMENTATION_FALLBACK - -namespace simdjson { -/** - * Fallback implementation (runs on any machine). 
- */ -namespace fallback {} // namespace fallback -} // namespace simdjson - -/* begin file include/simdjson/fallback/implementation.h */ -#ifndef SIMDJSON_FALLBACK_IMPLEMENTATION_H -#define SIMDJSON_FALLBACK_IMPLEMENTATION_H - - -namespace simdjson { -namespace fallback { - -namespace { -using namespace simdjson; -using namespace simdjson::dom; -} - -class implementation final : public simdjson::implementation { - public: - simdjson_really_inline implementation() - : simdjson::implementation( - "fallback", "Generic fallback implementation", 0) {} - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final; - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final; - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final; -}; - -} // namespace fallback -} // namespace simdjson - -#endif // SIMDJSON_FALLBACK_IMPLEMENTATION_H -/* end file include/simdjson/fallback/implementation.h */ - -/* begin file include/simdjson/fallback/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "fallback" -// #define SIMDJSON_IMPLEMENTATION fallback -/* end file include/simdjson/fallback/begin.h */ - -// Declarations -/* begin file include/simdjson/generic/dom_parser_implementation.h */ - -namespace simdjson { -namespace fallback { - -// expectation: sizeof(open_container) = 64/8. 
-struct open_container { - uint32_t tape_index; // where, on the tape, does the scope ([,{) begins - uint32_t count; // how many elements in the scope -}; // struct open_container - -static_assert(sizeof(open_container) == 64 / 8, - "Open container must be 64 bits"); - -class dom_parser_implementation final - : public internal::dom_parser_implementation { - public: - /** Tape location of each open { or [ */ - std::unique_ptr open_containers{}; - /** Whether each open container is a [ or { */ - std::unique_ptr is_array{}; - /** Buffer passed to stage 1 */ - const uint8_t *buf{}; - /** Length passed to stage 1 */ - size_t len{0}; - /** Document passed to stage 2 */ - dom::document *doc{}; - - inline dom_parser_implementation() noexcept; - inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - dom_parser_implementation(const dom_parser_implementation &) = delete; - dom_parser_implementation &operator=(const dom_parser_implementation &) = - delete; - - simdjson_warn_unused error_code parse(const uint8_t *buf, - size_t len, - dom::document &doc) noexcept final; - simdjson_warn_unused error_code stage1(const uint8_t *buf, - size_t len, - stage1_mode partial) noexcept final; - simdjson_warn_unused error_code stage2(dom::document &doc) noexcept final; - simdjson_warn_unused error_code - stage2_next(dom::document &doc) noexcept final; - inline simdjson_warn_unused error_code - set_capacity(size_t capacity) noexcept final; - inline simdjson_warn_unused error_code - set_max_depth(size_t max_depth) noexcept final; - - private: - simdjson_really_inline simdjson_warn_unused error_code - set_capacity_stage1(size_t capacity); -}; - -} // namespace fallback -} // namespace simdjson - -namespace simdjson { -namespace fallback { - -inline dom_parser_implementation::dom_parser_implementation() noexcept = - default; -inline 
dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -inline dom_parser_implementation &dom_parser_implementation::operator=( - dom_parser_implementation &&other) noexcept = default; - -// Leaving these here so they can be inlined if so desired -inline simdjson_warn_unused error_code -dom_parser_implementation::set_capacity(size_t capacity) noexcept { - if (capacity > SIMDJSON_MAXSIZE_BYTES) { - return CAPACITY; - } - // Stage 1 index output - size_t max_structures = SIMDJSON_ROUNDUP_N(capacity, 64) + 2 + 7; - structural_indexes.reset(new (std::nothrow) uint32_t[max_structures]); - if (!structural_indexes) { - _capacity = 0; - return MEMALLOC; - } - structural_indexes[0] = 0; - n_structural_indexes = 0; - - _capacity = capacity; - return SUCCESS; -} - -inline simdjson_warn_unused error_code -dom_parser_implementation::set_max_depth(size_t max_depth) noexcept { - // Stage 2 stacks - open_containers.reset(new (std::nothrow) open_container[max_depth]); - is_array.reset(new (std::nothrow) bool[max_depth]); - if (!is_array || !open_containers) { - _max_depth = 0; - return MEMALLOC; - } - - _max_depth = max_depth; - return SUCCESS; -} - -} // namespace fallback -} // namespace simdjson -/* end file include/simdjson/generic/dom_parser_implementation.h */ -/* begin file include/simdjson/fallback/bitmanipulation.h */ -#ifndef SIMDJSON_FALLBACK_BITMANIPULATION_H -#define SIMDJSON_FALLBACK_BITMANIPULATION_H - -#include - -namespace simdjson { -namespace fallback { -namespace { - -#if defined(_MSC_VER) && !defined(_M_ARM64) && !defined(_M_X64) -static inline unsigned char _BitScanForward64(unsigned long *ret, uint64_t x) { - unsigned long x0 = (unsigned long)x, top, bottom; - _BitScanForward(&top, (unsigned long)(x >> 32)); - _BitScanForward(&bottom, x0); - *ret = x0 ? 
bottom : 32 + top; - return x != 0; -} -static unsigned char _BitScanReverse64(unsigned long *ret, uint64_t x) { - unsigned long x1 = (unsigned long)(x >> 32), top, bottom; - _BitScanReverse(&top, x1); - _BitScanReverse(&bottom, (unsigned long)x); - *ret = x1 ? top + 32 : bottom; - return x != 0; -} -#endif - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int leading_zeroes(uint64_t input_num) { -#ifdef _MSC_VER - unsigned long leading_zero = 0; - // Search the mask data from most significant bit (MSB) - // to least significant bit (LSB) for a set bit (1). - if (_BitScanReverse64(&leading_zero, input_num)) - return (int)(63 - leading_zero); - else - return 64; -#else - return __builtin_clzll(input_num); -#endif // _MSC_VER -} - -} // unnamed namespace -} // namespace fallback -} // namespace simdjson - -#endif // SIMDJSON_FALLBACK_BITMANIPULATION_H -/* end file include/simdjson/fallback/bitmanipulation.h */ -/* begin file include/simdjson/generic/jsoncharutils.h */ - -namespace simdjson { -namespace fallback { -namespace { -namespace jsoncharutils { - -// return non-zero if not a structural or whitespace char -// zero otherwise -simdjson_really_inline uint32_t is_not_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace_negated[c]; -} - -simdjson_really_inline uint32_t is_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace[c]; -} - -// returns a value with the high 16 bits set if not valid -// otherwise returns the conversion of the 4 hex digits at src into the bottom -// 16 bits of the 32-bit return register -// -// see -// https://lemire.me/blog/2019/04/17/parsing-short-hexadecimal-strings-efficiently/ -static inline uint32_t hex_to_u32_nocheck( - const uint8_t *src) { // strictly speaking, static inline is a C-ism - uint32_t v1 = internal::digit_to_val32[630 + src[0]]; - uint32_t v2 = internal::digit_to_val32[420 + src[1]]; - uint32_t v3 = internal::digit_to_val32[210 + 
src[2]]; - uint32_t v4 = internal::digit_to_val32[0 + src[3]]; - return v1 | v2 | v3 | v4; -} - -// given a code point cp, writes to c -// the utf-8 code, outputting the length in -// bytes, if the length is zero, the code point -// is invalid -// -// This can possibly be made faster using pdep -// and clz and table lookups, but JSON documents -// have few escaped code points, and the following -// function looks cheap. -// -// Note: we assume that surrogates are treated separately -// -simdjson_really_inline size_t codepoint_to_utf8(uint32_t cp, uint8_t *c) { - if (cp <= 0x7F) { - c[0] = uint8_t(cp); - return 1; // ascii - } - if (cp <= 0x7FF) { - c[0] = uint8_t((cp >> 6) + 192); - c[1] = uint8_t((cp & 63) + 128); - return 2; // universal plane - // Surrogates are treated elsewhere... - //} //else if (0xd800 <= cp && cp <= 0xdfff) { - // return 0; // surrogates // could put assert here - } else if (cp <= 0xFFFF) { - c[0] = uint8_t((cp >> 12) + 224); - c[1] = uint8_t(((cp >> 6) & 63) + 128); - c[2] = uint8_t((cp & 63) + 128); - return 3; - } else if (cp <= - 0x10FFFF) { // if you know you have a valid code point, this - // is not needed - c[0] = uint8_t((cp >> 18) + 240); - c[1] = uint8_t(((cp >> 12) & 63) + 128); - c[2] = uint8_t(((cp >> 6) & 63) + 128); - c[3] = uint8_t((cp & 63) + 128); - return 4; - } - // will return 0 when the code point was too large. 
- return 0; // bad r -} - -#ifdef SIMDJSON_IS_32BITS // _umul128 for x86, arm -// this is a slow emulation routine for 32-bit -// -static simdjson_really_inline uint64_t __emulu(uint32_t x, uint32_t y) { - return x * (uint64_t)y; -} -static simdjson_really_inline uint64_t _umul128(uint64_t ab, - uint64_t cd, - uint64_t *hi) { - uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd); - uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd); - uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32)); - uint64_t adbc_carry = !!(adbc < ad); - uint64_t lo = bd + (adbc << 32); - *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + - (adbc_carry << 32) + !!(lo < bd); - return lo; -} -#endif - -using internal::value128; - -simdjson_really_inline value128 full_multiplication(uint64_t value1, - uint64_t value2) { - value128 answer; -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) -#ifdef _M_ARM64 - // ARM64 has native support for 64-bit multiplications, no need to emultate - answer.high = __umulh(value1, value2); - answer.low = value1 * value2; -#else - answer.low = _umul128( - value1, value2, &answer.high); // _umul128 not available on ARM64 -#endif // _M_ARM64 -#else // defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) - __uint128_t r = (static_cast<__uint128_t>(value1)) * value2; - answer.low = uint64_t(r); - answer.high = uint64_t(r >> 64); -#endif - return answer; -} - -} // namespace jsoncharutils -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file include/simdjson/generic/jsoncharutils.h */ -/* begin file include/simdjson/generic/atomparsing.h */ -namespace simdjson { -namespace fallback { -namespace { -/// @private -namespace atomparsing { - -// The string_to_uint32 is exclusively used to map literal strings to 32-bit -// values. 
-// We use memcpy instead of a pointer cast to avoid undefined behaviors since we -// cannot -// be certain that the character pointer will be properly aligned. -// You might think that using memcpy makes this function expensive, but you'd be -// wrong. -// All decent optimizing compilers (GCC, clang, Visual Studio) will compile -// string_to_uint32("false"); -// to the compile-time constant 1936482662. -simdjson_really_inline uint32_t string_to_uint32(const char *str) { - uint32_t val; - std::memcpy(&val, str, sizeof(uint32_t)); - return val; -} - - -// Again in str4ncmp we use a memcpy to avoid undefined behavior. The memcpy may -// appear expensive. -// Yet all decent optimizing compilers will compile memcpy to a single -// instruction, just about. -simdjson_warn_unused simdjson_really_inline uint32_t -str4ncmp(const uint8_t *src, const char *atom) { - uint32_t - srcval; // we want to avoid unaligned 32-bit loads (undefined in C/C++) - static_assert(sizeof(uint32_t) <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be larger than 4 bytes"); - std::memcpy(&srcval, src, sizeof(uint32_t)); - return srcval ^ string_to_uint32(atom); -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src) { - return (str4ncmp(src, "true") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_true_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "true"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src) { - return (str4ncmp(src + 1, "alse") | - jsoncharutils::is_not_structural_or_whitespace(src[5])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src, size_t len) { - if (len > 5) { - return is_valid_false_atom(src); - } else if (len == 5) { - return 
!str4ncmp(src + 1, "alse"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src) { - return (str4ncmp(src, "null") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_null_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "null"); - } else { - return false; - } -} - -} // namespace atomparsing -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file include/simdjson/generic/atomparsing.h */ -/* begin file include/simdjson/fallback/stringparsing.h */ -#ifndef SIMDJSON_FALLBACK_STRINGPARSING_H -#define SIMDJSON_FALLBACK_STRINGPARSING_H - - -namespace simdjson { -namespace fallback { -namespace { - -// Holds backslashes and quotes locations. -struct backslash_and_quote { - public: - static constexpr uint32_t BYTES_PROCESSED = 1; - simdjson_really_inline static backslash_and_quote copy_and_find( - const uint8_t *src, uint8_t *dst); - - simdjson_really_inline bool has_quote_first() { return c == '"'; } - simdjson_really_inline bool has_backslash() { return c == '\\'; } - simdjson_really_inline int quote_index() { return c == '"' ? 0 : 1; } - simdjson_really_inline int backslash_index() { return c == '\\' ? 
0 : 1; } - - uint8_t c; -}; // struct backslash_and_quote - -simdjson_really_inline backslash_and_quote -backslash_and_quote::copy_and_find(const uint8_t *src, uint8_t *dst) { - // store to dest unconditionally - we can overwrite the bits we don't like - // later - dst[0] = src[0]; - return {src[0]}; -} - -} // unnamed namespace -} // namespace fallback -} // namespace simdjson - -/* begin file include/simdjson/generic/stringparsing.h */ -// This file contains the common code every implementation uses -// It is intended to be included multiple times and compiled multiple times - -namespace simdjson { -namespace fallback { -namespace { -/// @private -namespace stringparsing { - -// begin copypasta -// These chars yield themselves: " \ / -// b -> backspace, f -> formfeed, n -> newline, r -> cr, t -> horizontal tab -// u not handled in this table as it's complex -static const uint8_t escape_map[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x4. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x5c, 0, 0, 0, // 0x5. - 0, 0, 0x08, 0, 0, 0, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0x0a, 0, // 0x6. - 0, 0, 0x0d, 0, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x7. 
- - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -// handle a unicode codepoint -// write appropriate values into dest -// src will advance 6 bytes or 12 bytes -// dest will advance a variable amount (return via pointer) -// return true if the unicode codepoint was valid -// We work in little-endian then swap at write time -simdjson_warn_unused simdjson_really_inline bool handle_unicode_codepoint( - const uint8_t **src_ptr, uint8_t **dst_ptr) { - // jsoncharutils::hex_to_u32_nocheck fills high 16 bits of the return value - // with 1s if the - // conversion isn't valid; we defer the check for this to inside the - // multilingual plane check - uint32_t code_point = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - *src_ptr += 6; - // check for low surrogate for characters outside the Basic - // Multilingual Plane. - if (code_point >= 0xd800 && code_point < 0xdc00) { - if (((*src_ptr)[0] != '\\') || (*src_ptr)[1] != 'u') { - return false; - } - uint32_t code_point_2 = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - - // if the first code point is invalid we will get here, as we will go - // past - // the check for being outside the Basic Multilingual plane. If we don't - // find a \u immediately afterwards we fail out anyhow, but if we do, - // this check catches both the case of the first code point being - // invalid - // or the second code point being invalid. 
- if ((code_point | code_point_2) >> 16) { - return false; - } - - code_point = - (((code_point - 0xd800) << 10) | (code_point_2 - 0xdc00)) + 0x10000; - *src_ptr += 6; - } - size_t offset = jsoncharutils::codepoint_to_utf8(code_point, *dst_ptr); - *dst_ptr += offset; - return offset > 0; -} - -/** - * Unescape a string from src to dst, stopping at a final unescaped quote. E.g., - * if src points at 'joe"', then - * dst needs to have four free bytes. - */ -simdjson_warn_unused simdjson_really_inline uint8_t *parse_string( - const uint8_t *src, uint8_t *dst) { - while (1) { - // Copy the next n bytes, and find the backslash and quote in them. - auto bs_quote = backslash_and_quote::copy_and_find(src, dst); - // If the next thing is the end quote, copy and return - if (bs_quote.has_quote_first()) { - // we encountered quotes first. Move dst to point to quotes and exit - return dst + bs_quote.quote_index(); - } - if (bs_quote.has_backslash()) { - /* find out where the backspace is */ - auto bs_dist = bs_quote.backslash_index(); - uint8_t escape_char = src[bs_dist + 1]; - /* we encountered backslash first. Handle backslash */ - if (escape_char == 'u') { - /* move src/dst up to the start; they will be further adjusted - within the unicode codepoint handling code. */ - src += bs_dist; - dst += bs_dist; - if (!handle_unicode_codepoint(&src, &dst)) { - return nullptr; - } - } else { - /* simple 1:1 conversion. Will eat bs_dist+2 characters in input - * and - * write bs_dist+1 characters to output - * note this may reach beyond the part of the buffer we've - * actually - * seen. I think this is ok */ - uint8_t escape_result = escape_map[escape_char]; - if (escape_result == 0u) { - return nullptr; /* bogus escape value is an error */ - } - dst[bs_dist] = escape_result; - src += bs_dist + 2; - dst += bs_dist + 1; - } - } else { - /* they are the same. Since they can't co-occur, it means we - * encountered neither. 
*/ - src += backslash_and_quote::BYTES_PROCESSED; - dst += backslash_and_quote::BYTES_PROCESSED; - } - } - /* can't be reached */ - return nullptr; -} - -simdjson_unused simdjson_warn_unused simdjson_really_inline error_code -parse_string_to_buffer(const uint8_t *src, - uint8_t *¤t_string_buf_loc, - std::string_view &s) { - if (*(src++) != '"') { - return STRING_ERROR; - } - auto end = stringparsing::parse_string(src, current_string_buf_loc); - if (!end) { - return STRING_ERROR; - } - s = std::string_view(reinterpret_cast(current_string_buf_loc), - end - current_string_buf_loc); - current_string_buf_loc = end; - return SUCCESS; -} - -} // namespace stringparsing -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file include/simdjson/generic/stringparsing.h */ - -#endif // SIMDJSON_FALLBACK_STRINGPARSING_H -/* end file include/simdjson/fallback/stringparsing.h */ -/* begin file include/simdjson/fallback/numberparsing.h */ -#ifndef SIMDJSON_FALLBACK_NUMBERPARSING_H -#define SIMDJSON_FALLBACK_NUMBERPARSING_H - -#ifdef JSON_TEST_NUMBERS // for unit testing -void found_invalid_number(const uint8_t *buf); -void found_integer(int64_t result, const uint8_t *buf); -void found_unsigned_integer(uint64_t result, const uint8_t *buf); -void found_float(double result, const uint8_t *buf); -#endif - -namespace simdjson { -namespace fallback { -namespace { -// credit: https://johnnylee-sde.github.io/Fast-numeric-string-to-int/ -static simdjson_really_inline uint32_t -parse_eight_digits_unrolled(const char *chars) { - uint64_t val; - memcpy(&val, chars, sizeof(uint64_t)); - val = (val & 0x0F0F0F0F0F0F0F0F) * 2561 >> 8; - val = (val & 0x00FF00FF00FF00FF) * 6553601 >> 16; - return uint32_t((val & 0x0000FFFF0000FFFF) * 42949672960001 >> 32); -} -static simdjson_really_inline uint32_t -parse_eight_digits_unrolled(const uint8_t *chars) { - return parse_eight_digits_unrolled(reinterpret_cast(chars)); -} - -} // unnamed namespace -} // namespace fallback -} 
// namespace simdjson - -#define SIMDJSON_SWAR_NUMBER_PARSING 1 - -/* begin file include/simdjson/generic/numberparsing.h */ -#include - -namespace simdjson { -namespace fallback { - -namespace ondemand { -/** - * The type of a JSON number - */ -enum class number_type { - floating_point_number = 1, /// a binary64 number - signed_integer, /// a signed integer that fits in a 64-bit word using two's - /// complement - unsigned_integer /// a positive integer larger or equal to 1<<63 -}; -} - -namespace { -/// @private -namespace numberparsing { - - -#ifdef JSON_TEST_NUMBERS -#define INVALID_NUMBER(SRC) (found_invalid_number((SRC)), NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) \ - (found_integer((VALUE), (SRC)), (WRITER).append_s64((VALUE))) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) \ - (found_unsigned_integer((VALUE), (SRC)), (WRITER).append_u64((VALUE))) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) \ - (found_float((VALUE), (SRC)), (WRITER).append_double((VALUE))) -#else -#define INVALID_NUMBER(SRC) (NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) (WRITER).append_s64((VALUE)) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) (WRITER).append_u64((VALUE)) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) (WRITER).append_double((VALUE)) -#endif - -namespace { -// Convert a mantissa, an exponent and a sign bit into an ieee64 double. -// The real_exponent needs to be in [0, 2046] (technically real_exponent = 2047 -// would be acceptable). -// The mantissa should be in [0,1<<53). The bit at index (1ULL << 52) while be -// zeroed. -simdjson_really_inline double to_double(uint64_t mantissa, - uint64_t real_exponent, - bool negative) { - double d; - mantissa &= ~(1ULL << 52); - mantissa |= real_exponent << 52; - mantissa |= ((static_cast(negative)) << 63); - std::memcpy(&d, &mantissa, sizeof(d)); - return d; -} -} -// Attempts to compute i * 10^(power) exactly; and if "negative" is -// true, negate the result. 
-// This function will only work in some cases, when it does not work, success is -// set to false. This should work *most of the time* (like 99% of the time). -// We assume that power is in the [smallest_power, -// largest_power] interval: the caller is responsible for this check. -simdjson_really_inline bool compute_float_64(int64_t power, - uint64_t i, - bool negative, - double &d) { -// we start with a fast path -// It was described in -// Clinger WD. How to read floating point numbers accurately. -// ACM SIGPLAN Notices. 1990 -#ifndef FLT_EVAL_METHOD -#error "FLT_EVAL_METHOD should be defined, please include cfloat." -#endif -#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) - // We cannot be certain that x/y is rounded to nearest. - if (0 <= power && power <= 22 && i <= 9007199254740991) { -#else - if (-22 <= power && power <= 22 && i <= 9007199254740991) { -#endif - // convert the integer into a double. This is lossless since - // 0 <= i <= 2^53 - 1. - d = double(i); - // - // The general idea is as follows. - // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then - // 1) Both s and p can be represented exactly as 64-bit floating-point - // values - // (binary64). - // 2) Because s and p can be represented exactly as floating-point - // values, - // then s * p - // and s / p will produce correctly rounded values. - // - if (power < 0) { - d = d / simdjson::internal::power_of_ten[-power]; - } else { - d = d * simdjson::internal::power_of_ten[power]; - } - if (negative) { - d = -d; - } - return true; - } - // When 22 < power && power < 22 + 16, we could - // hope for another, secondary fast path. It was - // described by David M. Gay in "Correctly rounded - // binary-decimal and decimal-binary conversions." (1990) - // If you need to compute i * 10^(22 + x) for x < 16, - // first compute i * 10^x, if you know that result is exact - // (e.g., when i * 10^x < 2^53), - // then you can still proceed and do (i * 10^x) * 10^22. - // Is this worth your time? 
- // You need 22 < power *and* power < 22 + 16 *and* (i * 10^(x-22) < 2^53) - // for this second fast path to work. - // If you you have 22 < power *and* power < 22 + 16, and then you - // optimistically compute "i * 10^(x-22)", there is still a chance that you - // have wasted your time if i * 10^(x-22) >= 2^53. It makes the use cases of - // this optimization maybe less common than we would like. Source: - // http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ - // also used in RapidJSON: https://rapidjson.org/strtod_8h_source.html - - // The fast path has now failed, so we are failing back on the slower path. - - // In the slow path, we need to adjust i so that it is > 1<<63 which is - // always - // possible, except if i == 0, so we handle i == 0 separately. - if (i == 0) { - d = 0.0; - return true; - } - - - // The exponent is 1024 + 63 + power - // + floor(log(5**power)/log(2)). - // The 1024 comes from the ieee64 standard. - // The 63 comes from the fact that we use a 64-bit word. - // - // Computing floor(log(5**power)/log(2)) could be - // slow. Instead we use a fast function. - // - // For power in (-400,350), we have that - // (((152170 + 65536) * power ) >> 16); - // is equal to - // floor(log(5**power)/log(2)) + power when power >= 0 - // and it is equal to - // ceil(log(5**-power)/log(2)) + power when power < 0 - // - // The 65536 is (1<<16) and corresponds to - // (65536 * power) >> 16 ---> power - // - // ((152170 * power ) >> 16) is equal to - // floor(log(5**power)/log(2)) - // - // Note that this is not magic: 152170/(1<<16) is - // approximatively equal to log(5)/log(2). - // The 1<<16 value is a power of two; we could use a - // larger power of 2 if we wanted to. - // - int64_t exponent = (((152170 + 65536) * power) >> 16) + 1024 + 63; - - - // We want the most significant bit of i to be 1. Shift if needed. 
- int lz = leading_zeroes(i); - i <<= lz; - - - // We are going to need to do some 64-bit arithmetic to get a precise - // product. - // We use a table lookup approach. - // It is safe because - // power >= smallest_power - // and power <= largest_power - // We recover the mantissa of the power, it has a leading 1. It is always - // rounded down. - // - // We want the most significant 64 bits of the product. We know - // this will be non-zero because the most significant bit of i is - // 1. - const uint32_t index = - 2 * uint32_t(power - simdjson::internal::smallest_power); - // Optimization: It may be that materializing the index as a variable might - // confuse some compilers and prevent effective complex-addressing loads. - // (Done for code clarity.) - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high component" - // corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 firstproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index]); - // Both i and power_of_five_128[index] have their most significant bit set - // to 1 which - // implies that the either the most or the second most significant bit of - // the product - // is 1. We pack values in this manner for efficiency reasons: it maximizes - // the use - // we make of the product. It also makes it easy to reason about the - // product: there - // is 0 or 1 leading zero in the product. - - // Unless the least significant 9 bits of the high (64-bit) part of the full - // product are all 1s, then we know that the most significant 55 bits are - // exact and no further work is needed. 
Having 55 bits is necessary because - // we need 53 bits for the mantissa but we have to have one rounding bit and - // we can waste a bit if the most significant bit of the product is zero. - if ((firstproduct.high & 0x1FF) == 0x1FF) { - // We want to compute i * 5^q, but only care about the top 55 bits at - // most. - // Consider the scenario where q>=0. Then 5^q may not fit in 64-bits. - // Doing - // the full computation is wasteful. So we do what is called a - // "truncated - // multiplication". - // We take the most significant 64-bits, and we put them in - // power_of_five_128[index]. Usually, that's good enough to approximate - // i * 5^q - // to the desired approximation using one multiplication. Sometimes it - // does not suffice. - // Then we store the next most significant 64 bits in - // power_of_five_128[index + 1], and - // then we get a better approximation to i * 5^q. In very rare cases, - // even that - // will not suffice, though it is seemingly very hard to find such a - // scenario. - // - // That's for when q>=0. The logic for q<0 is somewhat similar but it is - // somewhat - // more complicated. - // - // There is an extra layer of complexity in that we need more than 55 - // bits of - // accuracy in the round-to-even scenario. - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high - // component" corresponding - // to the 64-bit most significant bits of the product. 
- simdjson::internal::value128 secondproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index + 1]); - firstproduct.low += secondproduct.high; - if (secondproduct.high > firstproduct.low) { - firstproduct.high++; - } - // At this point, we might need to add at most one to firstproduct, but - // this - // can only change the value of firstproduct.high if firstproduct.low is - // maximal. - if (simdjson_unlikely(firstproduct.low == 0xFFFFFFFFFFFFFFFF)) { - // This is very unlikely, but if so, we need to do much more work! - return false; - } - } - uint64_t lower = firstproduct.low; - uint64_t upper = firstproduct.high; - // The final mantissa should be 53 bits with a leading 1. - // We shift it so that it occupies 54 bits with a leading 1. - /////// - uint64_t upperbit = upper >> 63; - uint64_t mantissa = upper >> (upperbit + 9); - lz += int(1 ^ upperbit); - - // Here we have mantissa < (1<<54). - int64_t real_exponent = exponent - lz; - if (simdjson_unlikely(real_exponent <= 0)) { // we have a subnormal? - // Here have that real_exponent <= 0 so -real_exponent >= 0 - if (-real_exponent + 1 >= 64) { // if we have more than 64 bits below - // the minimum exponent, you have a - // zero for sure. - d = 0.0; - return true; - } - // next line is safe because -real_exponent + 1 < 0 - mantissa >>= -real_exponent + 1; - // Thankfully, we can't have both "round-to-even" and subnormals because - // "round-to-even" only occurs for powers close to 0. - mantissa += (mantissa & 1); // round up - mantissa >>= 1; - // There is a weird scenario where we don't have a subnormal but just. - // Suppose we start with 2.2250738585072013e-308, we end up - // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal - // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to - // round - // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer - // subnormal, but we can only know this after rounding. 
- // So we only declare a subnormal if we are smaller than the threshold. - real_exponent = (mantissa < (uint64_t(1) << 52)) ? 0 : 1; - d = to_double(mantissa, real_exponent, negative); - return true; - } - // We have to round to even. The "to even" part - // is only a problem when we are right in between two floats - // which we guard against. - // If we have lots of trailing zeros, we may fall right between two - // floating-point values. - // - // The round-to-even cases take the form of a number 2m+1 which is in - // (2^53,2^54] - // times a power of two. That is, it is right between a number with binary - // significand - // m and another number with binary significand m+1; and it must be the case - // that it cannot be represented by a float itself. - // - // We must have that w * 10 ^q == (2m+1) * 2^p for some power of two 2^p. - // Recall that 10^q = 5^q * 2^q. - // When q >= 0, we must have that (2m+1) is divible by 5^q, so 5^q <= 2^54. - // We have that - // 5^23 <= 2^54 and it is the last power of five to qualify, so q <= 23. - // When q<0, we have w >= (2m+1) x 5^{-q}. We must have that w<2^{64} so - // (2m+1) x 5^{-q} < 2^{64}. We have that 2m+1>2^{53}. Hence, we must have - // 2^{53} x 5^{-q} < 2^{64}. - // Hence we have 5^{-q} < 2^{11}$ or q>= -4. - // - // We require lower <= 1 and not lower == 0 because we could not prove that - // that lower == 0 is implied; but we could prove that lower <= 1 is a - // necessary and sufficient test. 
- if (simdjson_unlikely((lower <= 1) && (power >= -4) && (power <= 23) && - ((mantissa & 3) == 1))) { - if ((mantissa << (upperbit + 64 - 53 - 2)) == upper) { - mantissa &= ~1; // flip it so that we do not round up - } - } - - mantissa += mantissa & 1; - mantissa >>= 1; - - // Here we have mantissa < (1<<53), unless there was an overflow - if (mantissa >= (1ULL << 53)) { - ////////// - // This will happen when parsing values such as 7.2057594037927933e+16 - //////// - mantissa = (1ULL << 52); - real_exponent++; - } - mantissa &= ~(1ULL << 52); - // we have to check that real_exponent is in range, otherwise we bail out - if (simdjson_unlikely(real_exponent > 2046)) { - // We have an infinite value!!! We could actually throw an error here if - // we could. - return false; - } - d = to_double(mantissa, real_exponent, negative); - return true; -} - -// We call a fallback floating-point parser that might be slow. Note -// it will accept JSON numbers, but the JSON spec. is more restrictive so -// before you call parse_float_fallback, you need to have validated the input -// string with the JSON grammar. -// It will return an error (false) if the parsed number is infinite. -// The string parsing itself always succeeds. We know that there is at least -// one digit. -static bool parse_float_fallback(const uint8_t *ptr, double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). 
- return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} -static bool parse_float_fallback(const uint8_t *ptr, - const uint8_t *end_ptr, - double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr), - reinterpret_cast(end_ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). - return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} - -// check quickly whether the next 8 chars are made of digits -// at a glance, it looks better than Mula's -// http://0x80.pl/articles/swar-digits-validate.html -simdjson_really_inline bool is_made_of_eight_digits_fast(const uint8_t *chars) { - uint64_t val; - // this can read up to 7 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(7 <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be bigger than 7"); - std::memcpy(&val, chars, 8); - // a branchy method might be faster: - // return (( val & 0xF0F0F0F0F0F0F0F0 ) == 0x3030303030303030) - // && (( (val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0 ) == - // 0x3030303030303030); - return (((val & 0xF0F0F0F0F0F0F0F0) | - (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) == - 0x3333333333333333); -} - -template -error_code slow_float_parsing(simdjson_unused const uint8_t *src, W writer) { - double d; - if (parse_float_fallback(src, &d)) { - writer.append_double(d); - return SUCCESS; - } - return INVALID_NUMBER(src); -} - -template -SIMDJSON_NO_SANITIZE_UNDEFINED // We deliberately allow overflow here and check - // later - 
simdjson_really_inline bool - parse_digit(const uint8_t c, I &i) { - const uint8_t digit = static_cast(c - '0'); - if (digit > 9) { - return false; - } - // PERF NOTE: multiplication by 10 is cheaper than arbitrary integer - // multiplication - i = 10 * i + digit; // might overflow, we will handle the overflow later - return true; -} - -simdjson_really_inline error_code -parse_decimal(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - uint64_t &i, - int64_t &exponent) { - // we continue with the fiction that we have an integer. If the - // floating point number is representable as x * 10^z for some integer - // z that fits in 53 bits, then we will be able to convert back the - // the integer into a float in a lossless manner. - const uint8_t *const first_after_period = p; - -#ifdef SIMDJSON_SWAR_NUMBER_PARSING -#if SIMDJSON_SWAR_NUMBER_PARSING - // this helps if we have lots of decimals! - // this turns out to be frequent enough. - if (is_made_of_eight_digits_fast(p)) { - i = i * 100000000 + parse_eight_digits_unrolled(p); - p += 8; - } -#endif // SIMDJSON_SWAR_NUMBER_PARSING -#endif // #ifdef SIMDJSON_SWAR_NUMBER_PARSING - // Unrolling the first digit makes a small difference on some - // implementations (e.g. westmere) - if (parse_digit(*p, i)) { - ++p; - } - while (parse_digit(*p, i)) { - p++; - } - exponent = first_after_period - p; - // Decimal without digits (123.) is illegal - if (exponent == 0) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -simdjson_really_inline error_code -parse_exponent(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - int64_t &exponent) { - // Exp Sign: -123.456e[-]78 - bool neg_exp = ('-' == *p); - if (neg_exp || '+' == *p) { - p++; - } // Skip + as well - - // Exponent: -123.456e-[78] - auto start_exp = p; - int64_t exp_number = 0; - while (parse_digit(*p, exp_number)) { - ++p; - } - // It is possible for parse_digit to overflow. 
- // In particular, it could overflow to INT64_MIN, and we cannot do - - // INT64_MIN. - // Thus we *must* check for possible overflow before we negate exp_number. - - // Performance notes: it may seem like combining the two "simdjson_unlikely - // checks" below into - // a single simdjson_unlikely path would be faster. The reasoning is sound, - // but the compiler may - // not oblige and may, in fact, generate two distinct paths in any case. It - // might be - // possible to do uint64_t(p - start_exp - 1) >= 18 but it could end up - // trading off - // instructions for a simdjson_likely branch, an unconclusive gain. - - // If there were no digits, it's an error. - if (simdjson_unlikely(p == start_exp)) { - return INVALID_NUMBER(src); - } - // We have a valid positive exponent in exp_number at this point, except - // that - // it may have overflowed. - - // If there were more than 18 digits, we may have overflowed the integer. We - // have to do - // something!!!! - if (simdjson_unlikely(p > start_exp + 18)) { - // Skip leading zeroes: 1e000000000000000000001 is technically valid and - // doesn't overflow - while (*start_exp == '0') { - start_exp++; - } - // 19 digits could overflow int64_t and is kind of absurd anyway. We - // don't - // support exponents smaller than -999,999,999,999,999,999 and bigger - // than 999,999,999,999,999,999. - // We can truncate. - // Note that 999999999999999999 is assuredly too large. The maximal - // ieee64 value before - // infinity is ~1.8e308. The smallest subnormal is ~5e-324. So, - // actually, we could - // truncate at 324. - // Note that there is no reason to fail per se at this point in time. - // E.g., 0e999999999999999999999 is a fine number. - if (p > start_exp + 18) { - exp_number = 999999999999999999; - } - } - // At this point, we know that exp_number is a sane, positive, signed - // integer. - // It is <= 999,999,999,999,999,999. 
As long as 'exponent' is in - // [-8223372036854775808, 8223372036854775808], we won't overflow. Because - // 'exponent' - // is bounded in magnitude by the size of the JSON input, we are fine in - // this universe. - // To sum it up: the next line should never overflow. - exponent += (neg_exp ? -exp_number : exp_number); - return SUCCESS; -} - -simdjson_really_inline size_t significant_digits(const uint8_t *start_digits, - size_t digit_count) { - // It is possible that the integer had an overflow. - // We have to handle the case where we have 0.0000somenumber. - const uint8_t *start = start_digits; - while ((*start == '0') || (*start == '.')) { - ++start; - } - // we over-decrement by one when there is a '.' - return digit_count - size_t(start - start_digits); -} - -template -simdjson_really_inline error_code write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer) { - // If we frequently had to deal with long strings of digits, - // we could extend our code by using a 128-bit integer instead - // of a 64-bit integer. However, this is uncommon in practice. - // - // 9999999999999999999 < 2**64 so we can accommodate 19 digits. - // If we have a decimal separator, then digit_count - 1 is the number of - // digits, but we - // may not have a decimal separator! - if (simdjson_unlikely(digit_count > 19 && - significant_digits(start_digits, digit_count) > 19)) { - // Ok, chances are good that we had an overflow! - // this is almost never going to get called!!! - // we start anew, going slowly!!! - // This will happen in the following examples: - // 10000000000000000000000000000000000000000000e+308 - // 3.1415926535897932384626433832795028841971693993751 - // - // NOTE: This makes a *copy* of the writer and passes it to - // slow_float_parsing. This happens - // because slow_float_parsing is a non-inlined function. 
If we passed - // our writer reference to - // it, it would force it to be stored in memory, preventing the compiler - // from picking it apart - // and putting into registers. i.e. if we pass it as reference, it gets - // slow. - // This is what forces the skip_double, as well. - error_code error = slow_float_parsing(src, writer); - writer.skip_double(); - return error; - } - // NOTE: it's weird that the simdjson_unlikely() only wraps half the if, but - // it seems to get slower any other - // way we've tried: - // https://github.com/simdjson/simdjson/pull/990#discussion_r448497331 - // To future reader: we'd love if someone found a better way, or at least - // could explain this result! - if (simdjson_unlikely(exponent < simdjson::internal::smallest_power) || - (exponent > simdjson::internal::largest_power)) { - // - // Important: smallest_power is such that it leads to a zero value. - // Observe that 18446744073709551615e-343 == 0, i.e. (2**64 - 1) e -343 - // is zero - // so something x 10^-343 goes to zero, but not so with something x - // 10^-342. - static_assert(simdjson::internal::smallest_power <= -342, - "smallest_power is not small enough"); - // - if ((exponent < simdjson::internal::smallest_power) || (i == 0)) { - WRITE_DOUBLE(0, src, writer); - return SUCCESS; - } else { // (exponent > largest_power) and (i != 0) - // We have, for sure, an infinite value and simdjson refuses to - // parse infinite values. - return INVALID_NUMBER(src); - } - } - double d; - if (!compute_float_64(exponent, i, negative, d)) { - // we are almost never going to get here. 
- if (!parse_float_fallback(src, &d)) { - return INVALID_NUMBER(src); - } - } - WRITE_DOUBLE(d, src, writer); - return SUCCESS; -} - -// for performance analysis, it is sometimes useful to skip parsing -#ifdef SIMDJSON_SKIPNUMBERPARSING - -template -simdjson_really_inline error_code parse_number(const uint8_t *const, - W &writer) { - writer.append_s64(0); // always write zero - return SUCCESS; // always succeeds -} - -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - return ondemand::number_type::signed_integer; -} -#else - -// parse the number at src -// define JSON_TEST_NUMBERS for unit testing -// -// It is assumed that the number is followed by a structural ({,},],[) character -// or a white space character. If that is not the case (e.g., when the JSON -// document is made of a single number), then it is necessary to copy the -// content and append a space before calling this function. 
-// -// Our objective is accurate parsing (ULP of 0) at high speed. -template -simdjson_really_inline error_code parse_number(const uint8_t *const src, - W &writer) { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - if (digit_count == 0 || ('0' == *start_digits && digit_count > 1)) { - return INVALID_NUMBER(src); - } - - // - // Handle floats if there is a . or e (or both) - // - int64_t exponent = 0; - bool is_float = false; - if ('.' == *p) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_decimal(src, p, i, exponent)); - digit_count = - int(p - start_digits); // used later to guard against overflows - } - if (('e' == *p) || ('E' == *p)) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_exponent(src, p, exponent)); - } - if (is_float) { - const bool dirty_end = - jsoncharutils::is_not_structural_or_whitespace(*p); - SIMDJSON_TRY(write_float( - src, negative, i, start_digits, digit_count, exponent, writer)); - if (dirty_end) { - return INVALID_NUMBER(src); - } - return SUCCESS; - } - - // The longest negative 64-bit number is 19 digits. - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - size_t longest_digit_count = negative ? 
19 : 20; - if (digit_count > longest_digit_count) { - return INVALID_NUMBER(src); - } - if (digit_count == longest_digit_count) { - if (negative) { - // Anything negative above INT64_MAX+1 is invalid - if (i > uint64_t(INT64_MAX) + 1) { - return INVALID_NUMBER(src); - } - WRITE_INTEGER(~i + 1, src, writer); - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less - // than INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit - // "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), - // the result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the - // user could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - } else if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INVALID_NUMBER(src); - } - } - - // Write unsigned if it doesn't fit in a signed integer. - if (i > uint64_t(INT64_MAX)) { - WRITE_UNSIGNED(i, src, writer); - } else { - WRITE_INTEGER(negative ? (~i + 1) : i, src, writer); - } - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -// Inlineable functions -namespace { - -// This table can be used to characterize the final character of an integer -// string. For JSON structural character and allowable white space characters, -// we return SUCCESS. For 'e', '.' and 'E', we return INCORRECT_TYPE. Otherwise -// we return NUMBER_ERROR. 
-// Optimization note: we could easily reduce the size of the table by half (to -// 128) -// at the cost of an extra branch. -// Optimization note: we want the values to use at most 8 bits (not, e.g., 32 -// bits): -static_assert(error_code(uint8_t(NUMBER_ERROR)) == NUMBER_ERROR, - "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(SUCCESS)) == SUCCESS, "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(INCORRECT_TYPE)) == INCORRECT_TYPE, - "bad NUMBER_ERROR cast"); - -const uint8_t integer_string_finisher[256] = { - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, INCORRECT_TYPE, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, SUCCESS, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR}; - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. 
- // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - - -// Parse any number from 0 to 18,446,744,073,709,551,615 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - const uint8_t *p = src + 1; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. 
- // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - // Note: we use src[1] and not src[0] because src[0] is the quote - // character in this - // instance. - if (src[1] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. 
- // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? 
(~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - // - // Check for minus sign - // - if (src == src_end) { - return NUMBER_ERROR; - } - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. 
- // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - const uint8_t *p = src + negative + 1; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return (*src == '-'); -} - -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if (jsoncharutils::is_structural_or_whitespace(*p)) { - return true; - } - return false; -} - -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if 
(jsoncharutils::is_structural_or_whitespace(*p)) { - int digit_count = int(p - src); - if (digit_count >= 19) { - const uint8_t *smaller_big_integer = - reinterpret_cast("9223372036854775808"); - if ((digit_count >= 20) || - (memcmp(src, smaller_big_integer, 19) >= 0)) { - return ondemand::number_type::unsigned_integer; - } - } - return ondemand::number_type::signed_integer; - } - return ondemand::number_type::floating_point_number; -} - -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src, const uint8_t *const src_end) noexcept { - if (src == src_end) { - return NUMBER_ERROR; - } - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - if (p == src_end) { - return NUMBER_ERROR; - } - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely((p != src_end) && (*p == '.'))) { - p++; - const uint8_t *start_decimal_digits = p; - if ((p == src_end) || !parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if ((p != src_end) && (*p == 'e' || *p == 'E')) { - p++; - if (p == src_end) { - return NUMBER_ERROR; - } - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while ((p != src_end) && parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if ((p != src_end) && jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, src_end, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - src += negative + 1; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. 
- // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. - overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 
0 - exp : exp; - } - - if (*p != '"') { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} -} // namespace {} -#endif // SIMDJSON_SKIPNUMBERPARSING - -} // namespace numberparsing -} // unnamed namespace -} // namespace fallback -} // namespace simdjson -/* end file include/simdjson/generic/numberparsing.h */ - -#endif // SIMDJSON_FALLBACK_NUMBERPARSING_H -/* end file include/simdjson/fallback/numberparsing.h */ -/* begin file include/simdjson/fallback/end.h */ -/* end file include/simdjson/fallback/end.h */ - -#endif // SIMDJSON_IMPLEMENTATION_FALLBACK -#endif // SIMDJSON_FALLBACK_H -/* end file include/simdjson/fallback.h */ -/* begin file include/simdjson/haswell.h */ -#ifndef SIMDJSON_HASWELL_H -#define SIMDJSON_HASWELL_H - - -#if SIMDJSON_IMPLEMENTATION_HASWELL - -#if SIMDJSON_CAN_ALWAYS_RUN_HASWELL -#define SIMDJSON_TARGET_HASWELL -#define SIMDJSON_UNTARGET_HASWELL -#else -#define SIMDJSON_TARGET_HASWELL SIMDJSON_TARGET_REGION("avx2,bmi,pclmul,lzcnt") -#define SIMDJSON_UNTARGET_HASWELL SIMDJSON_UNTARGET_REGION -#endif - -namespace simdjson { -/** - * Implementation for Haswell (Intel AVX2). 
- */ -namespace haswell {} // namespace haswell -} // namespace simdjson - -// -// These two need to be included outside SIMDJSON_TARGET_HASWELL -// -/* begin file include/simdjson/haswell/implementation.h */ -#ifndef SIMDJSON_HASWELL_IMPLEMENTATION_H -#define SIMDJSON_HASWELL_IMPLEMENTATION_H - - -// The constructor may be executed on any host, so we take care not to use -// SIMDJSON_TARGET_HASWELL -namespace simdjson { -namespace haswell { - -using namespace simdjson; - -class implementation final : public simdjson::implementation { - public: - simdjson_really_inline implementation() - : simdjson::implementation("haswell", - "Intel/AMD AVX2", - internal::instruction_set::AVX2 | - internal::instruction_set::PCLMULQDQ | - internal::instruction_set::BMI1 | - internal::instruction_set::BMI2) {} - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final; - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final; - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final; -}; - -} // namespace haswell -} // namespace simdjson - -#endif // SIMDJSON_HASWELL_IMPLEMENTATION_H -/* end file include/simdjson/haswell/implementation.h */ -/* begin file include/simdjson/haswell/intrinsics.h */ -#ifndef SIMDJSON_HASWELL_INTRINSICS_H -#define SIMDJSON_HASWELL_INTRINSICS_H - - -#ifdef SIMDJSON_VISUAL_STUDIO -// under clang within visual studio, this will include -#include // visual studio or clang -#else -#include // elsewhere -#endif // SIMDJSON_VISUAL_STUDIO - -#ifdef SIMDJSON_CLANG_VISUAL_STUDIO -/** - * You are not supposed, normally, to include these - * headers directly. Instead you should either include intrin.h - * or x86intrin.h. 
However, when compiling with clang - * under Windows (i.e., when _MSC_VER is set), these headers - * only get included *if* the corresponding features are detected - * from macros: - * e.g., if __AVX2__ is set... in turn, we normally set these - * macros by compiling against the corresponding architecture - * (e.g., arch:AVX2, -mavx2, etc.) which compiles the whole - * software with these advanced instructions. In simdjson, we - * want to compile the whole program for a generic target, - * and only target our specific kernels. As a workaround, - * we directly include the needed headers. These headers would - * normally guard against such usage, but we carefully included - * (or ) before, so the headers - * are fooled. - */ -#include -#include -#include // for _blsr_u64 -#include // for most things (AVX2, AVX512, _popcnt64) -#include // for __lzcnt64 -#include -#include -#include // for _mm_clmulepi64_si128 -// unfortunately, we may not get _blsr_u64, but, thankfully, clang -// has it as a macro. -#ifndef _blsr_u64 -// we roll our own -SIMDJSON_TARGET_HASWELL -static simdjson_really_inline uint64_t _blsr_u64(uint64_t n) { - return (n - 1) & n; -} -SIMDJSON_UNTARGET_HASWELL -#endif // _blsr_u64 -#endif // SIMDJSON_CLANG_VISUAL_STUDIO - -#endif // SIMDJSON_HASWELL_INTRINSICS_H -/* end file include/simdjson/haswell/intrinsics.h */ - -// -// The rest need to be inside the region -// -/* begin file include/simdjson/haswell/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "haswell" -// #define SIMDJSON_IMPLEMENTATION haswell -SIMDJSON_TARGET_HASWELL -/* end file include/simdjson/haswell/begin.h */ - -// Declarations -/* begin file include/simdjson/generic/dom_parser_implementation.h */ - -namespace simdjson { -namespace haswell { - -// expectation: sizeof(open_container) = 64/8. 
-struct open_container { - uint32_t tape_index; // where, on the tape, does the scope ([,{) begins - uint32_t count; // how many elements in the scope -}; // struct open_container - -static_assert(sizeof(open_container) == 64 / 8, - "Open container must be 64 bits"); - -class dom_parser_implementation final - : public internal::dom_parser_implementation { - public: - /** Tape location of each open { or [ */ - std::unique_ptr open_containers{}; - /** Whether each open container is a [ or { */ - std::unique_ptr is_array{}; - /** Buffer passed to stage 1 */ - const uint8_t *buf{}; - /** Length passed to stage 1 */ - size_t len{0}; - /** Document passed to stage 2 */ - dom::document *doc{}; - - inline dom_parser_implementation() noexcept; - inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - dom_parser_implementation(const dom_parser_implementation &) = delete; - dom_parser_implementation &operator=(const dom_parser_implementation &) = - delete; - - simdjson_warn_unused error_code parse(const uint8_t *buf, - size_t len, - dom::document &doc) noexcept final; - simdjson_warn_unused error_code stage1(const uint8_t *buf, - size_t len, - stage1_mode partial) noexcept final; - simdjson_warn_unused error_code stage2(dom::document &doc) noexcept final; - simdjson_warn_unused error_code - stage2_next(dom::document &doc) noexcept final; - inline simdjson_warn_unused error_code - set_capacity(size_t capacity) noexcept final; - inline simdjson_warn_unused error_code - set_max_depth(size_t max_depth) noexcept final; - - private: - simdjson_really_inline simdjson_warn_unused error_code - set_capacity_stage1(size_t capacity); -}; - -} // namespace haswell -} // namespace simdjson - -namespace simdjson { -namespace haswell { - -inline dom_parser_implementation::dom_parser_implementation() noexcept = - default; -inline 
dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -inline dom_parser_implementation &dom_parser_implementation::operator=( - dom_parser_implementation &&other) noexcept = default; - -// Leaving these here so they can be inlined if so desired -inline simdjson_warn_unused error_code -dom_parser_implementation::set_capacity(size_t capacity) noexcept { - if (capacity > SIMDJSON_MAXSIZE_BYTES) { - return CAPACITY; - } - // Stage 1 index output - size_t max_structures = SIMDJSON_ROUNDUP_N(capacity, 64) + 2 + 7; - structural_indexes.reset(new (std::nothrow) uint32_t[max_structures]); - if (!structural_indexes) { - _capacity = 0; - return MEMALLOC; - } - structural_indexes[0] = 0; - n_structural_indexes = 0; - - _capacity = capacity; - return SUCCESS; -} - -inline simdjson_warn_unused error_code -dom_parser_implementation::set_max_depth(size_t max_depth) noexcept { - // Stage 2 stacks - open_containers.reset(new (std::nothrow) open_container[max_depth]); - is_array.reset(new (std::nothrow) bool[max_depth]); - if (!is_array || !open_containers) { - _max_depth = 0; - return MEMALLOC; - } - - _max_depth = max_depth; - return SUCCESS; -} - -} // namespace haswell -} // namespace simdjson -/* end file include/simdjson/generic/dom_parser_implementation.h */ -/* begin file include/simdjson/haswell/bitmanipulation.h */ -#ifndef SIMDJSON_HASWELL_BITMANIPULATION_H -#define SIMDJSON_HASWELL_BITMANIPULATION_H - -namespace simdjson { -namespace haswell { -namespace { - -// We sometimes call trailing_zero on inputs that are zero, -// but the algorithms do not end up using the returned value. -// Sadly, sanitizers are not smart enough to figure it out. 
-SIMDJSON_NO_SANITIZE_UNDEFINED -simdjson_really_inline int trailing_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - return (int)_tzcnt_u64(input_num); -#else // SIMDJSON_REGULAR_VISUAL_STUDIO - //////// - // You might expect the next line to be equivalent to - // return (int)_tzcnt_u64(input_num); - // but the generated code differs and might be less efficient? - //////// - return __builtin_ctzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline uint64_t clear_lowest_bit(uint64_t input_num) { - return _blsr_u64(input_num); -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int leading_zeroes(uint64_t input_num) { - return int(_lzcnt_u64(input_num)); -} - -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO -simdjson_really_inline unsigned __int64 count_ones(uint64_t input_num) { - // note: we do not support legacy 32-bit Windows - return __popcnt64(input_num); // Visual Studio wants two underscores -} -#else -simdjson_really_inline long long int count_ones(uint64_t input_num) { - return _popcnt64(input_num); -} -#endif - -simdjson_really_inline bool add_overflow(uint64_t value1, - uint64_t value2, - uint64_t *result) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - return _addcarry_u64( - 0, value1, value2, reinterpret_cast(result)); -#else - return __builtin_uaddll_overflow( - value1, value2, reinterpret_cast(result)); -#endif -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -#endif // SIMDJSON_HASWELL_BITMANIPULATION_H -/* end file include/simdjson/haswell/bitmanipulation.h */ -/* begin file include/simdjson/haswell/bitmask.h */ -#ifndef SIMDJSON_HASWELL_BITMASK_H -#define SIMDJSON_HASWELL_BITMASK_H - -namespace simdjson { -namespace haswell { -namespace { - -// -// Perform a "cumulative bitwise xor," flipping bits each time a 1 is -// encountered. 
-// -// For example, prefix_xor(00100100) == 00011100 -// -simdjson_really_inline uint64_t prefix_xor(const uint64_t bitmask) { - // There should be no such thing with a processor supporting avx2 - // but not clmul. - __m128i all_ones = _mm_set1_epi8('\xFF'); - __m128i result = - _mm_clmulepi64_si128(_mm_set_epi64x(0ULL, bitmask), all_ones, 0); - return _mm_cvtsi128_si64(result); -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -#endif // SIMDJSON_HASWELL_BITMASK_H -/* end file include/simdjson/haswell/bitmask.h */ -/* begin file include/simdjson/haswell/simd.h */ -#ifndef SIMDJSON_HASWELL_SIMD_H -#define SIMDJSON_HASWELL_SIMD_H - - -namespace simdjson { -namespace haswell { -namespace { -namespace simd { - -// Forward-declared so they can be used by splat and friends. -template -struct base { - __m256i value; - - // Zero constructor - simdjson_really_inline base() : value{__m256i()} {} - - // Conversion from SIMD register - simdjson_really_inline base(const __m256i _value) : value(_value) {} - - // Conversion to SIMD register - simdjson_really_inline operator const __m256i &() const { - return this->value; - } - simdjson_really_inline operator __m256i &() { return this->value; } - - // Bit operations - simdjson_really_inline Child operator|(const Child other) const { - return _mm256_or_si256(*this, other); - } - simdjson_really_inline Child operator&(const Child other) const { - return _mm256_and_si256(*this, other); - } - simdjson_really_inline Child operator^(const Child other) const { - return _mm256_xor_si256(*this, other); - } - simdjson_really_inline Child bit_andnot(const Child other) const { - return _mm256_andnot_si256(other, *this); - } - simdjson_really_inline Child &operator|=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast | other; - return *this_cast; - } - simdjson_really_inline Child &operator&=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast & 
other; - return *this_cast; - } - simdjson_really_inline Child &operator^=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast ^ other; - return *this_cast; - } -}; - -// Forward-declared so they can be used by splat and friends. -template -struct simd8; - -template > -struct base8 : base> { - typedef uint32_t bitmask_t; - typedef uint64_t bitmask2_t; - - simdjson_really_inline base8() : base>() {} - simdjson_really_inline base8(const __m256i _value) - : base>(_value) {} - - friend simdjson_really_inline Mask operator==(const simd8 lhs, - const simd8 rhs) { - return _mm256_cmpeq_epi8(lhs, rhs); - } - - static const int SIZE = sizeof(base::value); - - template - simdjson_really_inline simd8 prev(const simd8 prev_chunk) const { - return _mm256_alignr_epi8( - *this, _mm256_permute2x128_si256(prev_chunk, *this, 0x21), 16 - N); - } -}; - -// SIMD byte mask type (returned by things like eq and gt) -template <> -struct simd8 : base8 { - static simdjson_really_inline simd8 splat(bool _value) { - return _mm256_set1_epi8(uint8_t(-(!!_value))); - } - - simdjson_really_inline simd8() : base8() {} - simdjson_really_inline simd8(const __m256i _value) - : base8(_value) {} - // Splat constructor - simdjson_really_inline simd8(bool _value) - : base8(splat(_value)) {} - - simdjson_really_inline int to_bitmask() const { - return _mm256_movemask_epi8(*this); - } - simdjson_really_inline bool any() const { - return !_mm256_testz_si256(*this, *this); - } - simdjson_really_inline simd8 operator~() const { - return *this ^ true; - } -}; - -template -struct base8_numeric : base8 { - static simdjson_really_inline simd8 splat(T _value) { - return _mm256_set1_epi8(_value); - } - static simdjson_really_inline simd8 zero() { - return _mm256_setzero_si256(); - } - static simdjson_really_inline simd8 load(const T values[32]) { - return _mm256_loadu_si256(reinterpret_cast(values)); - } - // Repeat 16 values as many times as necessary (usually for lookup tables) - 
static simdjson_really_inline simd8 repeat_16(T v0, - T v1, - T v2, - T v3, - T v4, - T v5, - T v6, - T v7, - T v8, - T v9, - T v10, - T v11, - T v12, - T v13, - T v14, - T v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15, - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - simdjson_really_inline base8_numeric() : base8() {} - simdjson_really_inline base8_numeric(const __m256i _value) - : base8(_value) {} - - // Store to array - simdjson_really_inline void store(T dst[32]) const { - return _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst), *this); - } - - // Addition/subtraction are the same for signed and unsigned - simdjson_really_inline simd8 operator+(const simd8 other) const { - return _mm256_add_epi8(*this, other); - } - simdjson_really_inline simd8 operator-(const simd8 other) const { - return _mm256_sub_epi8(*this, other); - } - simdjson_really_inline simd8 &operator+=(const simd8 other) { - *this = *this + other; - return *static_cast *>(this); - } - simdjson_really_inline simd8 &operator-=(const simd8 other) { - *this = *this - other; - return *static_cast *>(this); - } - - // Override to distinguish from bool version - simdjson_really_inline simd8 operator~() const { return *this ^ 0xFFu; } - - // Perform a lookup assuming the value is between 0 and 16 (undefined - // behavior for out of range values) - template - simdjson_really_inline simd8 lookup_16(simd8 lookup_table) const { - return _mm256_shuffle_epi8(lookup_table, *this); - } - - // Copies to 'output" all bytes corresponding to a 0 in the mask - // (interpreted as a bitset). - // Passing a 0 value for mask would be equivalent to writing out every byte - // to output. - // Only the first 32 - count_ones(mask) bytes of the result are significant - // but 32 bytes - // get written. 
- // Design consideration: it seems like a function with the - // signature simd8 compress(uint32_t mask) would be - // sensible, but the AVX ISA makes this kind of approach difficult. - template - simdjson_really_inline void compress(uint32_t mask, L *output) const { - using internal::thintable_epi8; - using internal::BitsSetTable256mul2; - using internal::pshufb_combine_table; - // this particular implementation was inspired by work done by - // @animetosho - // we do it in four steps, first 8 bytes and then second 8 bytes... - uint8_t mask1 = uint8_t(mask); // least significant 8 bits - uint8_t mask2 = uint8_t(mask >> 8); // second least significant 8 bits - uint8_t mask3 = uint8_t(mask >> 16); // ... - uint8_t mask4 = uint8_t(mask >> 24); // ... - // next line just loads the 64-bit values thintable_epi8[mask1] and - // thintable_epi8[mask2] into a 128-bit register, using only - // two instructions on most compilers. - __m256i shufmask = _mm256_set_epi64x(thintable_epi8[mask4], - thintable_epi8[mask3], - thintable_epi8[mask2], - thintable_epi8[mask1]); - // we increment by 0x08 the second half of the mask and so forth - shufmask = _mm256_add_epi8(shufmask, - _mm256_set_epi32(0x18181818, - 0x18181818, - 0x10101010, - 0x10101010, - 0x08080808, - 0x08080808, - 0, - 0)); - // this is the version "nearly pruned" - __m256i pruned = _mm256_shuffle_epi8(*this, shufmask); - // we still need to put the pieces back together. - // we compute the popcount of the first words: - int pop1 = BitsSetTable256mul2[mask1]; - int pop3 = BitsSetTable256mul2[mask3]; - - // then load the corresponding mask - // could be done with _mm256_loadu2_m128i but many standard libraries - // omit this intrinsic. 
- __m256i v256 = _mm256_castsi128_si256( - _mm_loadu_si128(reinterpret_cast( - pshufb_combine_table + pop1 * 8))); - __m256i compactmask = _mm256_insertf128_si256( - v256, - _mm_loadu_si128(reinterpret_cast( - pshufb_combine_table + pop3 * 8)), - 1); - __m256i almostthere = _mm256_shuffle_epi8(pruned, compactmask); - // We just need to write out the result. - // This is the tricky bit that is hard to do - // if we want to return a SIMD register, since there - // is no single-instruction approach to recombine - // the two 128-bit lanes with an offset. - __m128i v128; - v128 = _mm256_castsi256_si128(almostthere); - _mm_storeu_si128(reinterpret_cast<__m128i *>(output), v128); - v128 = _mm256_extractf128_si256(almostthere, 1); - _mm_storeu_si128(reinterpret_cast<__m128i *>(output + 16 - - count_ones(mask & 0xFFFF)), - v128); - } - - template - simdjson_really_inline simd8 lookup_16(L replace0, - L replace1, - L replace2, - L replace3, - L replace4, - L replace5, - L replace6, - L replace7, - L replace8, - L replace9, - L replace10, - L replace11, - L replace12, - L replace13, - L replace14, - L replace15) const { - return lookup_16(simd8::repeat_16(replace0, - replace1, - replace2, - replace3, - replace4, - replace5, - replace6, - replace7, - replace8, - replace9, - replace10, - replace11, - replace12, - replace13, - replace14, - replace15)); - } -}; - -// Signed bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m256i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(int8_t _value) : simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const int8_t values[32]) - : simd8(load(values)) {} - // Member-by-member initialization - simdjson_really_inline simd8(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - 
int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15, - int8_t v16, - int8_t v17, - int8_t v18, - int8_t v19, - int8_t v20, - int8_t v21, - int8_t v22, - int8_t v23, - int8_t v24, - int8_t v25, - int8_t v26, - int8_t v27, - int8_t v28, - int8_t v29, - int8_t v30, - int8_t v31) - : simd8(_mm256_setr_epi8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15, - v16, - v17, - v18, - v19, - v20, - v21, - v22, - v23, - v24, - v25, - v26, - v27, - v28, - v29, - v30, - v31)) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15, - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Order-sensitive comparisons - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return _mm256_max_epi8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return _mm256_min_epi8(*this, other); - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return _mm256_cmpgt_epi8(*this, other); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return _mm256_cmpgt_epi8(other, *this); - } -}; - -// Unsigned bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m256i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(uint8_t _value) : simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const uint8_t values[32]) - : simd8(load(values)) {} - // 
Member-by-member initialization - simdjson_really_inline simd8(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15, - uint8_t v16, - uint8_t v17, - uint8_t v18, - uint8_t v19, - uint8_t v20, - uint8_t v21, - uint8_t v22, - uint8_t v23, - uint8_t v24, - uint8_t v25, - uint8_t v26, - uint8_t v27, - uint8_t v28, - uint8_t v29, - uint8_t v30, - uint8_t v31) - : simd8(_mm256_setr_epi8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15, - v16, - v17, - v18, - v19, - v20, - v21, - v22, - v23, - v24, - v25, - v26, - v27, - v28, - v29, - v30, - v31)) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15, - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Saturated math - simdjson_really_inline simd8 saturating_add( - const simd8 other) const { - return _mm256_adds_epu8(*this, other); - } - simdjson_really_inline simd8 saturating_sub( - const simd8 other) const { - return _mm256_subs_epu8(*this, other); - } - - // Order-specific operations - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return _mm256_max_epu8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return _mm256_min_epu8(other, *this); - } - // Same as >, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 gt_bits( - const 
simd8 other) const { - return this->saturating_sub(other); - } - // Same as <, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 lt_bits( - const simd8 other) const { - return other.saturating_sub(*this); - } - simdjson_really_inline simd8 operator<=( - const simd8 other) const { - return other.max_val(*this) == other; - } - simdjson_really_inline simd8 operator>=( - const simd8 other) const { - return other.min_val(*this) == other; - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return this->gt_bits(other).any_bits_set(); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return this->lt_bits(other).any_bits_set(); - } - - // Bit-specific operations - simdjson_really_inline simd8 bits_not_set() const { - return *this == uint8_t(0); - } - simdjson_really_inline simd8 bits_not_set(simd8 bits) const { - return (*this & bits).bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set() const { - return ~this->bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set(simd8 bits) const { - return ~this->bits_not_set(bits); - } - simdjson_really_inline bool is_ascii() const { - return _mm256_movemask_epi8(*this) == 0; - } - simdjson_really_inline bool bits_not_set_anywhere() const { - return _mm256_testz_si256(*this, *this); - } - simdjson_really_inline bool any_bits_set_anywhere() const { - return !bits_not_set_anywhere(); - } - simdjson_really_inline bool bits_not_set_anywhere( - simd8 bits) const { - return _mm256_testz_si256(*this, bits); - } - simdjson_really_inline bool any_bits_set_anywhere( - simd8 bits) const { - return !bits_not_set_anywhere(bits); - } - template - simdjson_really_inline simd8 shr() const { - return simd8(_mm256_srli_epi16(*this, N)) & - uint8_t(0xFFu >> N); - } - template - simdjson_really_inline simd8 shl() const { - return simd8(_mm256_slli_epi16(*this, N)) & - uint8_t(0xFFu << N); - } - // Get one of the bits and make a bitmask out of 
it. - // e.g. value.get_bit<7>() gets the high bit - template - simdjson_really_inline int get_bit() const { - return _mm256_movemask_epi8(_mm256_slli_epi16(*this, 7 - N)); - } -}; - -template -struct simd8x64 { - static constexpr int NUM_CHUNKS = 64 / sizeof(simd8); - static_assert(NUM_CHUNKS == 2, - "Haswell kernel should use two registers per 64-byte block."); - const simd8 chunks[NUM_CHUNKS]; - - simd8x64(const simd8x64 &o) = delete; // no copy allowed - simd8x64 &operator=(const simd8 &other) = - delete; // no assignment allowed - simd8x64() = delete; // no default constructor allowed - - simdjson_really_inline simd8x64(const simd8 chunk0, - const simd8 chunk1) - : chunks{chunk0, chunk1} {} - simdjson_really_inline simd8x64(const T ptr[64]) - : chunks{simd8::load(ptr), simd8::load(ptr + 32)} {} - - simdjson_really_inline uint64_t compress(uint64_t mask, T *output) const { - uint32_t mask1 = uint32_t(mask); - uint32_t mask2 = uint32_t(mask >> 32); - this->chunks[0].compress(mask1, output); - this->chunks[1].compress(mask2, output + 32 - count_ones(mask1)); - return 64 - count_ones(mask); - } - - simdjson_really_inline void store(T ptr[64]) const { - this->chunks[0].store(ptr + sizeof(simd8) * 0); - this->chunks[1].store(ptr + sizeof(simd8) * 1); - } - - simdjson_really_inline uint64_t to_bitmask() const { - uint64_t r_lo = uint32_t(this->chunks[0].to_bitmask()); - uint64_t r_hi = this->chunks[1].to_bitmask(); - return r_lo | (r_hi << 32); - } - - simdjson_really_inline simd8 reduce_or() const { - return this->chunks[0] | this->chunks[1]; - } - - simdjson_really_inline simd8x64 bit_or(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] | mask, this->chunks[1] | mask); - } - - simdjson_really_inline uint64_t eq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] == mask, this->chunks[1] == mask) - .to_bitmask(); - } - - simdjson_really_inline uint64_t eq(const simd8x64 &other) const { - 
return simd8x64(this->chunks[0] == other.chunks[0], - this->chunks[1] == other.chunks[1]) - .to_bitmask(); - } - - simdjson_really_inline uint64_t lteq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] <= mask, this->chunks[1] <= mask) - .to_bitmask(); - } -}; // struct simd8x64 - -} // namespace simd - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -#endif // SIMDJSON_HASWELL_SIMD_H -/* end file include/simdjson/haswell/simd.h */ -/* begin file include/simdjson/generic/jsoncharutils.h */ - -namespace simdjson { -namespace haswell { -namespace { -namespace jsoncharutils { - -// return non-zero if not a structural or whitespace char -// zero otherwise -simdjson_really_inline uint32_t is_not_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace_negated[c]; -} - -simdjson_really_inline uint32_t is_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace[c]; -} - -// returns a value with the high 16 bits set if not valid -// otherwise returns the conversion of the 4 hex digits at src into the bottom -// 16 bits of the 32-bit return register -// -// see -// https://lemire.me/blog/2019/04/17/parsing-short-hexadecimal-strings-efficiently/ -static inline uint32_t hex_to_u32_nocheck( - const uint8_t *src) { // strictly speaking, static inline is a C-ism - uint32_t v1 = internal::digit_to_val32[630 + src[0]]; - uint32_t v2 = internal::digit_to_val32[420 + src[1]]; - uint32_t v3 = internal::digit_to_val32[210 + src[2]]; - uint32_t v4 = internal::digit_to_val32[0 + src[3]]; - return v1 | v2 | v3 | v4; -} - -// given a code point cp, writes to c -// the utf-8 code, outputting the length in -// bytes, if the length is zero, the code point -// is invalid -// -// This can possibly be made faster using pdep -// and clz and table lookups, but JSON documents -// have few escaped code points, and the following -// function looks cheap. 
-// -// Note: we assume that surrogates are treated separately -// -simdjson_really_inline size_t codepoint_to_utf8(uint32_t cp, uint8_t *c) { - if (cp <= 0x7F) { - c[0] = uint8_t(cp); - return 1; // ascii - } - if (cp <= 0x7FF) { - c[0] = uint8_t((cp >> 6) + 192); - c[1] = uint8_t((cp & 63) + 128); - return 2; // universal plane - // Surrogates are treated elsewhere... - //} //else if (0xd800 <= cp && cp <= 0xdfff) { - // return 0; // surrogates // could put assert here - } else if (cp <= 0xFFFF) { - c[0] = uint8_t((cp >> 12) + 224); - c[1] = uint8_t(((cp >> 6) & 63) + 128); - c[2] = uint8_t((cp & 63) + 128); - return 3; - } else if (cp <= - 0x10FFFF) { // if you know you have a valid code point, this - // is not needed - c[0] = uint8_t((cp >> 18) + 240); - c[1] = uint8_t(((cp >> 12) & 63) + 128); - c[2] = uint8_t(((cp >> 6) & 63) + 128); - c[3] = uint8_t((cp & 63) + 128); - return 4; - } - // will return 0 when the code point was too large. - return 0; // bad r -} - -#ifdef SIMDJSON_IS_32BITS // _umul128 for x86, arm -// this is a slow emulation routine for 32-bit -// -static simdjson_really_inline uint64_t __emulu(uint32_t x, uint32_t y) { - return x * (uint64_t)y; -} -static simdjson_really_inline uint64_t _umul128(uint64_t ab, - uint64_t cd, - uint64_t *hi) { - uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd); - uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd); - uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32)); - uint64_t adbc_carry = !!(adbc < ad); - uint64_t lo = bd + (adbc << 32); - *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + - (adbc_carry << 32) + !!(lo < bd); - return lo; -} -#endif - -using internal::value128; - -simdjson_really_inline value128 full_multiplication(uint64_t value1, - uint64_t value2) { - value128 answer; -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) -#ifdef _M_ARM64 - // ARM64 has native support for 64-bit multiplications, no need to emultate - 
answer.high = __umulh(value1, value2); - answer.low = value1 * value2; -#else - answer.low = _umul128( - value1, value2, &answer.high); // _umul128 not available on ARM64 -#endif // _M_ARM64 -#else // defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) - __uint128_t r = (static_cast<__uint128_t>(value1)) * value2; - answer.low = uint64_t(r); - answer.high = uint64_t(r >> 64); -#endif - return answer; -} - -} // namespace jsoncharutils -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file include/simdjson/generic/jsoncharutils.h */ -/* begin file include/simdjson/generic/atomparsing.h */ -namespace simdjson { -namespace haswell { -namespace { -/// @private -namespace atomparsing { - -// The string_to_uint32 is exclusively used to map literal strings to 32-bit -// values. -// We use memcpy instead of a pointer cast to avoid undefined behaviors since we -// cannot -// be certain that the character pointer will be properly aligned. -// You might think that using memcpy makes this function expensive, but you'd be -// wrong. -// All decent optimizing compilers (GCC, clang, Visual Studio) will compile -// string_to_uint32("false"); -// to the compile-time constant 1936482662. -simdjson_really_inline uint32_t string_to_uint32(const char *str) { - uint32_t val; - std::memcpy(&val, str, sizeof(uint32_t)); - return val; -} - - -// Again in str4ncmp we use a memcpy to avoid undefined behavior. The memcpy may -// appear expensive. -// Yet all decent optimizing compilers will compile memcpy to a single -// instruction, just about. 
-simdjson_warn_unused simdjson_really_inline uint32_t -str4ncmp(const uint8_t *src, const char *atom) { - uint32_t - srcval; // we want to avoid unaligned 32-bit loads (undefined in C/C++) - static_assert(sizeof(uint32_t) <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be larger than 4 bytes"); - std::memcpy(&srcval, src, sizeof(uint32_t)); - return srcval ^ string_to_uint32(atom); -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src) { - return (str4ncmp(src, "true") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_true_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "true"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src) { - return (str4ncmp(src + 1, "alse") | - jsoncharutils::is_not_structural_or_whitespace(src[5])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src, size_t len) { - if (len > 5) { - return is_valid_false_atom(src); - } else if (len == 5) { - return !str4ncmp(src + 1, "alse"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src) { - return (str4ncmp(src, "null") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_null_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "null"); - } else { - return false; - } -} - -} // namespace atomparsing -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file include/simdjson/generic/atomparsing.h */ -/* begin file include/simdjson/haswell/stringparsing.h */ -#ifndef SIMDJSON_HASWELL_STRINGPARSING_H 
-#define SIMDJSON_HASWELL_STRINGPARSING_H - - -namespace simdjson { -namespace haswell { -namespace { - -using namespace simd; - -// Holds backslashes and quotes locations. -struct backslash_and_quote { - public: - static constexpr uint32_t BYTES_PROCESSED = 32; - simdjson_really_inline static backslash_and_quote copy_and_find( - const uint8_t *src, uint8_t *dst); - - simdjson_really_inline bool has_quote_first() { - return ((bs_bits - 1) & quote_bits) != 0; - } - simdjson_really_inline bool has_backslash() { - return ((quote_bits - 1) & bs_bits) != 0; - } - simdjson_really_inline int quote_index() { - return trailing_zeroes(quote_bits); - } - simdjson_really_inline int backslash_index() { - return trailing_zeroes(bs_bits); - } - - uint32_t bs_bits; - uint32_t quote_bits; -}; // struct backslash_and_quote - -simdjson_really_inline backslash_and_quote -backslash_and_quote::copy_and_find(const uint8_t *src, uint8_t *dst) { - // this can read up to 15 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(SIMDJSON_PADDING >= (BYTES_PROCESSED - 1), - "backslash and quote finder must process fewer than " - "SIMDJSON_PADDING bytes"); - simd8 v(src); - // store to dest unconditionally - we can overwrite the bits we don't like - // later - v.store(dst); - return { - static_cast((v == '\\').to_bitmask()), // bs_bits - static_cast((v == '"').to_bitmask()), // quote_bits - }; -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -/* begin file include/simdjson/generic/stringparsing.h */ -// This file contains the common code every implementation uses -// It is intended to be included multiple times and compiled multiple times - -namespace simdjson { -namespace haswell { -namespace { -/// @private -namespace stringparsing { - -// begin copypasta -// These chars yield themselves: " \ / -// b -> backspace, f -> formfeed, n -> newline, r -> cr, t -> horizontal tab -// u not handled in this table as it's complex 
-static const uint8_t escape_map[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x4. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x5c, 0, 0, 0, // 0x5. - 0, 0, 0x08, 0, 0, 0, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0x0a, 0, // 0x6. - 0, 0, 0x0d, 0, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x7. - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -// handle a unicode codepoint -// write appropriate values into dest -// src will advance 6 bytes or 12 bytes -// dest will advance a variable amount (return via pointer) -// return true if the unicode codepoint was valid -// We work in little-endian then swap at write time -simdjson_warn_unused simdjson_really_inline bool handle_unicode_codepoint( - const uint8_t **src_ptr, uint8_t **dst_ptr) { - // jsoncharutils::hex_to_u32_nocheck fills high 16 bits of the return value - // with 1s if the - // conversion isn't valid; we defer the check for this to inside the - // multilingual plane check - uint32_t code_point = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - *src_ptr += 6; - // check for low surrogate for characters outside the Basic - // Multilingual Plane. 
- if (code_point >= 0xd800 && code_point < 0xdc00) { - if (((*src_ptr)[0] != '\\') || (*src_ptr)[1] != 'u') { - return false; - } - uint32_t code_point_2 = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - - // if the first code point is invalid we will get here, as we will go - // past - // the check for being outside the Basic Multilingual plane. If we don't - // find a \u immediately afterwards we fail out anyhow, but if we do, - // this check catches both the case of the first code point being - // invalid - // or the second code point being invalid. - if ((code_point | code_point_2) >> 16) { - return false; - } - - code_point = - (((code_point - 0xd800) << 10) | (code_point_2 - 0xdc00)) + 0x10000; - *src_ptr += 6; - } - size_t offset = jsoncharutils::codepoint_to_utf8(code_point, *dst_ptr); - *dst_ptr += offset; - return offset > 0; -} - -/** - * Unescape a string from src to dst, stopping at a final unescaped quote. E.g., - * if src points at 'joe"', then - * dst needs to have four free bytes. - */ -simdjson_warn_unused simdjson_really_inline uint8_t *parse_string( - const uint8_t *src, uint8_t *dst) { - while (1) { - // Copy the next n bytes, and find the backslash and quote in them. - auto bs_quote = backslash_and_quote::copy_and_find(src, dst); - // If the next thing is the end quote, copy and return - if (bs_quote.has_quote_first()) { - // we encountered quotes first. Move dst to point to quotes and exit - return dst + bs_quote.quote_index(); - } - if (bs_quote.has_backslash()) { - /* find out where the backspace is */ - auto bs_dist = bs_quote.backslash_index(); - uint8_t escape_char = src[bs_dist + 1]; - /* we encountered backslash first. Handle backslash */ - if (escape_char == 'u') { - /* move src/dst up to the start; they will be further adjusted - within the unicode codepoint handling code. */ - src += bs_dist; - dst += bs_dist; - if (!handle_unicode_codepoint(&src, &dst)) { - return nullptr; - } - } else { - /* simple 1:1 conversion. 
Will eat bs_dist+2 characters in input - * and - * write bs_dist+1 characters to output - * note this may reach beyond the part of the buffer we've - * actually - * seen. I think this is ok */ - uint8_t escape_result = escape_map[escape_char]; - if (escape_result == 0u) { - return nullptr; /* bogus escape value is an error */ - } - dst[bs_dist] = escape_result; - src += bs_dist + 2; - dst += bs_dist + 1; - } - } else { - /* they are the same. Since they can't co-occur, it means we - * encountered neither. */ - src += backslash_and_quote::BYTES_PROCESSED; - dst += backslash_and_quote::BYTES_PROCESSED; - } - } - /* can't be reached */ - return nullptr; -} - -simdjson_unused simdjson_warn_unused simdjson_really_inline error_code -parse_string_to_buffer(const uint8_t *src, - uint8_t *¤t_string_buf_loc, - std::string_view &s) { - if (*(src++) != '"') { - return STRING_ERROR; - } - auto end = stringparsing::parse_string(src, current_string_buf_loc); - if (!end) { - return STRING_ERROR; - } - s = std::string_view(reinterpret_cast(current_string_buf_loc), - end - current_string_buf_loc); - current_string_buf_loc = end; - return SUCCESS; -} - -} // namespace stringparsing -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file include/simdjson/generic/stringparsing.h */ - -#endif // SIMDJSON_HASWELL_STRINGPARSING_H -/* end file include/simdjson/haswell/stringparsing.h */ -/* begin file include/simdjson/haswell/numberparsing.h */ -#ifndef SIMDJSON_HASWELL_NUMBERPARSING_H -#define SIMDJSON_HASWELL_NUMBERPARSING_H - -namespace simdjson { -namespace haswell { -namespace { - -static simdjson_really_inline uint32_t -parse_eight_digits_unrolled(const uint8_t *chars) { - // this actually computes *16* values so we are being wasteful. 
- const __m128i ascii0 = _mm_set1_epi8('0'); - const __m128i mul_1_10 = - _mm_setr_epi8(10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1); - const __m128i mul_1_100 = _mm_setr_epi16(100, 1, 100, 1, 100, 1, 100, 1); - const __m128i mul_1_10000 = - _mm_setr_epi16(10000, 1, 10000, 1, 10000, 1, 10000, 1); - const __m128i input = _mm_sub_epi8( - _mm_loadu_si128(reinterpret_cast(chars)), ascii0); - const __m128i t1 = _mm_maddubs_epi16(input, mul_1_10); - const __m128i t2 = _mm_madd_epi16(t1, mul_1_100); - const __m128i t3 = _mm_packus_epi32(t2, t2); - const __m128i t4 = _mm_madd_epi16(t3, mul_1_10000); - return _mm_cvtsi128_si32( - t4); // only captures the sum of the first 8 digits, drop the rest -} - -} // unnamed namespace -} // namespace haswell -} // namespace simdjson - -#define SIMDJSON_SWAR_NUMBER_PARSING 1 - -/* begin file include/simdjson/generic/numberparsing.h */ -#include - -namespace simdjson { -namespace haswell { - -namespace ondemand { -/** - * The type of a JSON number - */ -enum class number_type { - floating_point_number = 1, /// a binary64 number - signed_integer, /// a signed integer that fits in a 64-bit word using two's - /// complement - unsigned_integer /// a positive integer larger or equal to 1<<63 -}; -} - -namespace { -/// @private -namespace numberparsing { - - -#ifdef JSON_TEST_NUMBERS -#define INVALID_NUMBER(SRC) (found_invalid_number((SRC)), NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) \ - (found_integer((VALUE), (SRC)), (WRITER).append_s64((VALUE))) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) \ - (found_unsigned_integer((VALUE), (SRC)), (WRITER).append_u64((VALUE))) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) \ - (found_float((VALUE), (SRC)), (WRITER).append_double((VALUE))) -#else -#define INVALID_NUMBER(SRC) (NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) (WRITER).append_s64((VALUE)) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) (WRITER).append_u64((VALUE)) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) 
(WRITER).append_double((VALUE)) -#endif - -namespace { -// Convert a mantissa, an exponent and a sign bit into an ieee64 double. -// The real_exponent needs to be in [0, 2046] (technically real_exponent = 2047 -// would be acceptable). -// The mantissa should be in [0,1<<53). The bit at index (1ULL << 52) while be -// zeroed. -simdjson_really_inline double to_double(uint64_t mantissa, - uint64_t real_exponent, - bool negative) { - double d; - mantissa &= ~(1ULL << 52); - mantissa |= real_exponent << 52; - mantissa |= ((static_cast(negative)) << 63); - std::memcpy(&d, &mantissa, sizeof(d)); - return d; -} -} -// Attempts to compute i * 10^(power) exactly; and if "negative" is -// true, negate the result. -// This function will only work in some cases, when it does not work, success is -// set to false. This should work *most of the time* (like 99% of the time). -// We assume that power is in the [smallest_power, -// largest_power] interval: the caller is responsible for this check. -simdjson_really_inline bool compute_float_64(int64_t power, - uint64_t i, - bool negative, - double &d) { -// we start with a fast path -// It was described in -// Clinger WD. How to read floating point numbers accurately. -// ACM SIGPLAN Notices. 1990 -#ifndef FLT_EVAL_METHOD -#error "FLT_EVAL_METHOD should be defined, please include cfloat." -#endif -#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) - // We cannot be certain that x/y is rounded to nearest. - if (0 <= power && power <= 22 && i <= 9007199254740991) { -#else - if (-22 <= power && power <= 22 && i <= 9007199254740991) { -#endif - // convert the integer into a double. This is lossless since - // 0 <= i <= 2^53 - 1. - d = double(i); - // - // The general idea is as follows. - // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then - // 1) Both s and p can be represented exactly as 64-bit floating-point - // values - // (binary64). 
- // 2) Because s and p can be represented exactly as floating-point - // values, - // then s * p - // and s / p will produce correctly rounded values. - // - if (power < 0) { - d = d / simdjson::internal::power_of_ten[-power]; - } else { - d = d * simdjson::internal::power_of_ten[power]; - } - if (negative) { - d = -d; - } - return true; - } - // When 22 < power && power < 22 + 16, we could - // hope for another, secondary fast path. It was - // described by David M. Gay in "Correctly rounded - // binary-decimal and decimal-binary conversions." (1990) - // If you need to compute i * 10^(22 + x) for x < 16, - // first compute i * 10^x, if you know that result is exact - // (e.g., when i * 10^x < 2^53), - // then you can still proceed and do (i * 10^x) * 10^22. - // Is this worth your time? - // You need 22 < power *and* power < 22 + 16 *and* (i * 10^(x-22) < 2^53) - // for this second fast path to work. - // If you you have 22 < power *and* power < 22 + 16, and then you - // optimistically compute "i * 10^(x-22)", there is still a chance that you - // have wasted your time if i * 10^(x-22) >= 2^53. It makes the use cases of - // this optimization maybe less common than we would like. Source: - // http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ - // also used in RapidJSON: https://rapidjson.org/strtod_8h_source.html - - // The fast path has now failed, so we are failing back on the slower path. - - // In the slow path, we need to adjust i so that it is > 1<<63 which is - // always - // possible, except if i == 0, so we handle i == 0 separately. - if (i == 0) { - d = 0.0; - return true; - } - - - // The exponent is 1024 + 63 + power - // + floor(log(5**power)/log(2)). - // The 1024 comes from the ieee64 standard. - // The 63 comes from the fact that we use a 64-bit word. - // - // Computing floor(log(5**power)/log(2)) could be - // slow. Instead we use a fast function. 
- // - // For power in (-400,350), we have that - // (((152170 + 65536) * power ) >> 16); - // is equal to - // floor(log(5**power)/log(2)) + power when power >= 0 - // and it is equal to - // ceil(log(5**-power)/log(2)) + power when power < 0 - // - // The 65536 is (1<<16) and corresponds to - // (65536 * power) >> 16 ---> power - // - // ((152170 * power ) >> 16) is equal to - // floor(log(5**power)/log(2)) - // - // Note that this is not magic: 152170/(1<<16) is - // approximatively equal to log(5)/log(2). - // The 1<<16 value is a power of two; we could use a - // larger power of 2 if we wanted to. - // - int64_t exponent = (((152170 + 65536) * power) >> 16) + 1024 + 63; - - - // We want the most significant bit of i to be 1. Shift if needed. - int lz = leading_zeroes(i); - i <<= lz; - - - // We are going to need to do some 64-bit arithmetic to get a precise - // product. - // We use a table lookup approach. - // It is safe because - // power >= smallest_power - // and power <= largest_power - // We recover the mantissa of the power, it has a leading 1. It is always - // rounded down. - // - // We want the most significant 64 bits of the product. We know - // this will be non-zero because the most significant bit of i is - // 1. - const uint32_t index = - 2 * uint32_t(power - simdjson::internal::smallest_power); - // Optimization: It may be that materializing the index as a variable might - // confuse some compilers and prevent effective complex-addressing loads. - // (Done for code clarity.) - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high component" - // corresponding - // to the 64-bit most significant bits of the product. 
- simdjson::internal::value128 firstproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index]); - // Both i and power_of_five_128[index] have their most significant bit set - // to 1 which - // implies that the either the most or the second most significant bit of - // the product - // is 1. We pack values in this manner for efficiency reasons: it maximizes - // the use - // we make of the product. It also makes it easy to reason about the - // product: there - // is 0 or 1 leading zero in the product. - - // Unless the least significant 9 bits of the high (64-bit) part of the full - // product are all 1s, then we know that the most significant 55 bits are - // exact and no further work is needed. Having 55 bits is necessary because - // we need 53 bits for the mantissa but we have to have one rounding bit and - // we can waste a bit if the most significant bit of the product is zero. - if ((firstproduct.high & 0x1FF) == 0x1FF) { - // We want to compute i * 5^q, but only care about the top 55 bits at - // most. - // Consider the scenario where q>=0. Then 5^q may not fit in 64-bits. - // Doing - // the full computation is wasteful. So we do what is called a - // "truncated - // multiplication". - // We take the most significant 64-bits, and we put them in - // power_of_five_128[index]. Usually, that's good enough to approximate - // i * 5^q - // to the desired approximation using one multiplication. Sometimes it - // does not suffice. - // Then we store the next most significant 64 bits in - // power_of_five_128[index + 1], and - // then we get a better approximation to i * 5^q. In very rare cases, - // even that - // will not suffice, though it is seemingly very hard to find such a - // scenario. - // - // That's for when q>=0. The logic for q<0 is somewhat similar but it is - // somewhat - // more complicated. 
- // - // There is an extra layer of complexity in that we need more than 55 - // bits of - // accuracy in the round-to-even scenario. - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high - // component" corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 secondproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index + 1]); - firstproduct.low += secondproduct.high; - if (secondproduct.high > firstproduct.low) { - firstproduct.high++; - } - // At this point, we might need to add at most one to firstproduct, but - // this - // can only change the value of firstproduct.high if firstproduct.low is - // maximal. - if (simdjson_unlikely(firstproduct.low == 0xFFFFFFFFFFFFFFFF)) { - // This is very unlikely, but if so, we need to do much more work! - return false; - } - } - uint64_t lower = firstproduct.low; - uint64_t upper = firstproduct.high; - // The final mantissa should be 53 bits with a leading 1. - // We shift it so that it occupies 54 bits with a leading 1. - /////// - uint64_t upperbit = upper >> 63; - uint64_t mantissa = upper >> (upperbit + 9); - lz += int(1 ^ upperbit); - - // Here we have mantissa < (1<<54). - int64_t real_exponent = exponent - lz; - if (simdjson_unlikely(real_exponent <= 0)) { // we have a subnormal? - // Here have that real_exponent <= 0 so -real_exponent >= 0 - if (-real_exponent + 1 >= 64) { // if we have more than 64 bits below - // the minimum exponent, you have a - // zero for sure. - d = 0.0; - return true; - } - // next line is safe because -real_exponent + 1 < 0 - mantissa >>= -real_exponent + 1; - // Thankfully, we can't have both "round-to-even" and subnormals because - // "round-to-even" only occurs for powers close to 0. 
- mantissa += (mantissa & 1); // round up - mantissa >>= 1; - // There is a weird scenario where we don't have a subnormal but just. - // Suppose we start with 2.2250738585072013e-308, we end up - // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal - // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to - // round - // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer - // subnormal, but we can only know this after rounding. - // So we only declare a subnormal if we are smaller than the threshold. - real_exponent = (mantissa < (uint64_t(1) << 52)) ? 0 : 1; - d = to_double(mantissa, real_exponent, negative); - return true; - } - // We have to round to even. The "to even" part - // is only a problem when we are right in between two floats - // which we guard against. - // If we have lots of trailing zeros, we may fall right between two - // floating-point values. - // - // The round-to-even cases take the form of a number 2m+1 which is in - // (2^53,2^54] - // times a power of two. That is, it is right between a number with binary - // significand - // m and another number with binary significand m+1; and it must be the case - // that it cannot be represented by a float itself. - // - // We must have that w * 10 ^q == (2m+1) * 2^p for some power of two 2^p. - // Recall that 10^q = 5^q * 2^q. - // When q >= 0, we must have that (2m+1) is divible by 5^q, so 5^q <= 2^54. - // We have that - // 5^23 <= 2^54 and it is the last power of five to qualify, so q <= 23. - // When q<0, we have w >= (2m+1) x 5^{-q}. We must have that w<2^{64} so - // (2m+1) x 5^{-q} < 2^{64}. We have that 2m+1>2^{53}. Hence, we must have - // 2^{53} x 5^{-q} < 2^{64}. - // Hence we have 5^{-q} < 2^{11}$ or q>= -4. - // - // We require lower <= 1 and not lower == 0 because we could not prove that - // that lower == 0 is implied; but we could prove that lower <= 1 is a - // necessary and sufficient test. 
- if (simdjson_unlikely((lower <= 1) && (power >= -4) && (power <= 23) && - ((mantissa & 3) == 1))) { - if ((mantissa << (upperbit + 64 - 53 - 2)) == upper) { - mantissa &= ~1; // flip it so that we do not round up - } - } - - mantissa += mantissa & 1; - mantissa >>= 1; - - // Here we have mantissa < (1<<53), unless there was an overflow - if (mantissa >= (1ULL << 53)) { - ////////// - // This will happen when parsing values such as 7.2057594037927933e+16 - //////// - mantissa = (1ULL << 52); - real_exponent++; - } - mantissa &= ~(1ULL << 52); - // we have to check that real_exponent is in range, otherwise we bail out - if (simdjson_unlikely(real_exponent > 2046)) { - // We have an infinite value!!! We could actually throw an error here if - // we could. - return false; - } - d = to_double(mantissa, real_exponent, negative); - return true; -} - -// We call a fallback floating-point parser that might be slow. Note -// it will accept JSON numbers, but the JSON spec. is more restrictive so -// before you call parse_float_fallback, you need to have validated the input -// string with the JSON grammar. -// It will return an error (false) if the parsed number is infinite. -// The string parsing itself always succeeds. We know that there is at least -// one digit. -static bool parse_float_fallback(const uint8_t *ptr, double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). 
- return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} -static bool parse_float_fallback(const uint8_t *ptr, - const uint8_t *end_ptr, - double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr), - reinterpret_cast(end_ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). - return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} - -// check quickly whether the next 8 chars are made of digits -// at a glance, it looks better than Mula's -// http://0x80.pl/articles/swar-digits-validate.html -simdjson_really_inline bool is_made_of_eight_digits_fast(const uint8_t *chars) { - uint64_t val; - // this can read up to 7 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(7 <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be bigger than 7"); - std::memcpy(&val, chars, 8); - // a branchy method might be faster: - // return (( val & 0xF0F0F0F0F0F0F0F0 ) == 0x3030303030303030) - // && (( (val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0 ) == - // 0x3030303030303030); - return (((val & 0xF0F0F0F0F0F0F0F0) | - (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) == - 0x3333333333333333); -} - -template -error_code slow_float_parsing(simdjson_unused const uint8_t *src, W writer) { - double d; - if (parse_float_fallback(src, &d)) { - writer.append_double(d); - return SUCCESS; - } - return INVALID_NUMBER(src); -} - -template -SIMDJSON_NO_SANITIZE_UNDEFINED // We deliberately allow overflow here and check - // later - 
simdjson_really_inline bool - parse_digit(const uint8_t c, I &i) { - const uint8_t digit = static_cast(c - '0'); - if (digit > 9) { - return false; - } - // PERF NOTE: multiplication by 10 is cheaper than arbitrary integer - // multiplication - i = 10 * i + digit; // might overflow, we will handle the overflow later - return true; -} - -simdjson_really_inline error_code -parse_decimal(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - uint64_t &i, - int64_t &exponent) { - // we continue with the fiction that we have an integer. If the - // floating point number is representable as x * 10^z for some integer - // z that fits in 53 bits, then we will be able to convert back the - // the integer into a float in a lossless manner. - const uint8_t *const first_after_period = p; - -#ifdef SIMDJSON_SWAR_NUMBER_PARSING -#if SIMDJSON_SWAR_NUMBER_PARSING - // this helps if we have lots of decimals! - // this turns out to be frequent enough. - if (is_made_of_eight_digits_fast(p)) { - i = i * 100000000 + parse_eight_digits_unrolled(p); - p += 8; - } -#endif // SIMDJSON_SWAR_NUMBER_PARSING -#endif // #ifdef SIMDJSON_SWAR_NUMBER_PARSING - // Unrolling the first digit makes a small difference on some - // implementations (e.g. westmere) - if (parse_digit(*p, i)) { - ++p; - } - while (parse_digit(*p, i)) { - p++; - } - exponent = first_after_period - p; - // Decimal without digits (123.) is illegal - if (exponent == 0) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -simdjson_really_inline error_code -parse_exponent(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - int64_t &exponent) { - // Exp Sign: -123.456e[-]78 - bool neg_exp = ('-' == *p); - if (neg_exp || '+' == *p) { - p++; - } // Skip + as well - - // Exponent: -123.456e-[78] - auto start_exp = p; - int64_t exp_number = 0; - while (parse_digit(*p, exp_number)) { - ++p; - } - // It is possible for parse_digit to overflow. 
- // In particular, it could overflow to INT64_MIN, and we cannot do - - // INT64_MIN. - // Thus we *must* check for possible overflow before we negate exp_number. - - // Performance notes: it may seem like combining the two "simdjson_unlikely - // checks" below into - // a single simdjson_unlikely path would be faster. The reasoning is sound, - // but the compiler may - // not oblige and may, in fact, generate two distinct paths in any case. It - // might be - // possible to do uint64_t(p - start_exp - 1) >= 18 but it could end up - // trading off - // instructions for a simdjson_likely branch, an unconclusive gain. - - // If there were no digits, it's an error. - if (simdjson_unlikely(p == start_exp)) { - return INVALID_NUMBER(src); - } - // We have a valid positive exponent in exp_number at this point, except - // that - // it may have overflowed. - - // If there were more than 18 digits, we may have overflowed the integer. We - // have to do - // something!!!! - if (simdjson_unlikely(p > start_exp + 18)) { - // Skip leading zeroes: 1e000000000000000000001 is technically valid and - // doesn't overflow - while (*start_exp == '0') { - start_exp++; - } - // 19 digits could overflow int64_t and is kind of absurd anyway. We - // don't - // support exponents smaller than -999,999,999,999,999,999 and bigger - // than 999,999,999,999,999,999. - // We can truncate. - // Note that 999999999999999999 is assuredly too large. The maximal - // ieee64 value before - // infinity is ~1.8e308. The smallest subnormal is ~5e-324. So, - // actually, we could - // truncate at 324. - // Note that there is no reason to fail per se at this point in time. - // E.g., 0e999999999999999999999 is a fine number. - if (p > start_exp + 18) { - exp_number = 999999999999999999; - } - } - // At this point, we know that exp_number is a sane, positive, signed - // integer. - // It is <= 999,999,999,999,999,999. 
As long as 'exponent' is in - // [-8223372036854775808, 8223372036854775808], we won't overflow. Because - // 'exponent' - // is bounded in magnitude by the size of the JSON input, we are fine in - // this universe. - // To sum it up: the next line should never overflow. - exponent += (neg_exp ? -exp_number : exp_number); - return SUCCESS; -} - -simdjson_really_inline size_t significant_digits(const uint8_t *start_digits, - size_t digit_count) { - // It is possible that the integer had an overflow. - // We have to handle the case where we have 0.0000somenumber. - const uint8_t *start = start_digits; - while ((*start == '0') || (*start == '.')) { - ++start; - } - // we over-decrement by one when there is a '.' - return digit_count - size_t(start - start_digits); -} - -template -simdjson_really_inline error_code write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer) { - // If we frequently had to deal with long strings of digits, - // we could extend our code by using a 128-bit integer instead - // of a 64-bit integer. However, this is uncommon in practice. - // - // 9999999999999999999 < 2**64 so we can accommodate 19 digits. - // If we have a decimal separator, then digit_count - 1 is the number of - // digits, but we - // may not have a decimal separator! - if (simdjson_unlikely(digit_count > 19 && - significant_digits(start_digits, digit_count) > 19)) { - // Ok, chances are good that we had an overflow! - // this is almost never going to get called!!! - // we start anew, going slowly!!! - // This will happen in the following examples: - // 10000000000000000000000000000000000000000000e+308 - // 3.1415926535897932384626433832795028841971693993751 - // - // NOTE: This makes a *copy* of the writer and passes it to - // slow_float_parsing. This happens - // because slow_float_parsing is a non-inlined function. 
If we passed - // our writer reference to - // it, it would force it to be stored in memory, preventing the compiler - // from picking it apart - // and putting into registers. i.e. if we pass it as reference, it gets - // slow. - // This is what forces the skip_double, as well. - error_code error = slow_float_parsing(src, writer); - writer.skip_double(); - return error; - } - // NOTE: it's weird that the simdjson_unlikely() only wraps half the if, but - // it seems to get slower any other - // way we've tried: - // https://github.com/simdjson/simdjson/pull/990#discussion_r448497331 - // To future reader: we'd love if someone found a better way, or at least - // could explain this result! - if (simdjson_unlikely(exponent < simdjson::internal::smallest_power) || - (exponent > simdjson::internal::largest_power)) { - // - // Important: smallest_power is such that it leads to a zero value. - // Observe that 18446744073709551615e-343 == 0, i.e. (2**64 - 1) e -343 - // is zero - // so something x 10^-343 goes to zero, but not so with something x - // 10^-342. - static_assert(simdjson::internal::smallest_power <= -342, - "smallest_power is not small enough"); - // - if ((exponent < simdjson::internal::smallest_power) || (i == 0)) { - WRITE_DOUBLE(0, src, writer); - return SUCCESS; - } else { // (exponent > largest_power) and (i != 0) - // We have, for sure, an infinite value and simdjson refuses to - // parse infinite values. - return INVALID_NUMBER(src); - } - } - double d; - if (!compute_float_64(exponent, i, negative, d)) { - // we are almost never going to get here. 
- if (!parse_float_fallback(src, &d)) { - return INVALID_NUMBER(src); - } - } - WRITE_DOUBLE(d, src, writer); - return SUCCESS; -} - -// for performance analysis, it is sometimes useful to skip parsing -#ifdef SIMDJSON_SKIPNUMBERPARSING - -template -simdjson_really_inline error_code parse_number(const uint8_t *const, - W &writer) { - writer.append_s64(0); // always write zero - return SUCCESS; // always succeeds -} - -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - return ondemand::number_type::signed_integer; -} -#else - -// parse the number at src -// define JSON_TEST_NUMBERS for unit testing -// -// It is assumed that the number is followed by a structural ({,},],[) character -// or a white space character. If that is not the case (e.g., when the JSON -// document is made of a single number), then it is necessary to copy the -// content and append a space before calling this function. 
-// -// Our objective is accurate parsing (ULP of 0) at high speed. -template -simdjson_really_inline error_code parse_number(const uint8_t *const src, - W &writer) { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - if (digit_count == 0 || ('0' == *start_digits && digit_count > 1)) { - return INVALID_NUMBER(src); - } - - // - // Handle floats if there is a . or e (or both) - // - int64_t exponent = 0; - bool is_float = false; - if ('.' == *p) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_decimal(src, p, i, exponent)); - digit_count = - int(p - start_digits); // used later to guard against overflows - } - if (('e' == *p) || ('E' == *p)) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_exponent(src, p, exponent)); - } - if (is_float) { - const bool dirty_end = - jsoncharutils::is_not_structural_or_whitespace(*p); - SIMDJSON_TRY(write_float( - src, negative, i, start_digits, digit_count, exponent, writer)); - if (dirty_end) { - return INVALID_NUMBER(src); - } - return SUCCESS; - } - - // The longest negative 64-bit number is 19 digits. - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - size_t longest_digit_count = negative ? 
19 : 20; - if (digit_count > longest_digit_count) { - return INVALID_NUMBER(src); - } - if (digit_count == longest_digit_count) { - if (negative) { - // Anything negative above INT64_MAX+1 is invalid - if (i > uint64_t(INT64_MAX) + 1) { - return INVALID_NUMBER(src); - } - WRITE_INTEGER(~i + 1, src, writer); - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less - // than INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit - // "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), - // the result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the - // user could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - } else if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INVALID_NUMBER(src); - } - } - - // Write unsigned if it doesn't fit in a signed integer. - if (i > uint64_t(INT64_MAX)) { - WRITE_UNSIGNED(i, src, writer); - } else { - WRITE_INTEGER(negative ? (~i + 1) : i, src, writer); - } - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -// Inlineable functions -namespace { - -// This table can be used to characterize the final character of an integer -// string. For JSON structural character and allowable white space characters, -// we return SUCCESS. For 'e', '.' and 'E', we return INCORRECT_TYPE. Otherwise -// we return NUMBER_ERROR. 
-// Optimization note: we could easily reduce the size of the table by half (to -// 128) -// at the cost of an extra branch. -// Optimization note: we want the values to use at most 8 bits (not, e.g., 32 -// bits): -static_assert(error_code(uint8_t(NUMBER_ERROR)) == NUMBER_ERROR, - "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(SUCCESS)) == SUCCESS, "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(INCORRECT_TYPE)) == INCORRECT_TYPE, - "bad NUMBER_ERROR cast"); - -const uint8_t integer_string_finisher[256] = { - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, INCORRECT_TYPE, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, SUCCESS, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR}; - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. 
- // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - - -// Parse any number from 0 to 18,446,744,073,709,551,615 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - const uint8_t *p = src + 1; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. 
- // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - // Note: we use src[1] and not src[0] because src[0] is the quote - // character in this - // instance. - if (src[1] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. 
- // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? 
(~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - // - // Check for minus sign - // - if (src == src_end) { - return NUMBER_ERROR; - } - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. 
- // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - const uint8_t *p = src + negative + 1; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return (*src == '-'); -} - -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if (jsoncharutils::is_structural_or_whitespace(*p)) { - return true; - } - return false; -} - -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if 
(jsoncharutils::is_structural_or_whitespace(*p)) { - int digit_count = int(p - src); - if (digit_count >= 19) { - const uint8_t *smaller_big_integer = - reinterpret_cast("9223372036854775808"); - if ((digit_count >= 20) || - (memcmp(src, smaller_big_integer, 19) >= 0)) { - return ondemand::number_type::unsigned_integer; - } - } - return ondemand::number_type::signed_integer; - } - return ondemand::number_type::floating_point_number; -} - -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src, const uint8_t *const src_end) noexcept { - if (src == src_end) { - return NUMBER_ERROR; - } - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - if (p == src_end) { - return NUMBER_ERROR; - } - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely((p != src_end) && (*p == '.'))) { - p++; - const uint8_t *start_decimal_digits = p; - if ((p == src_end) || !parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if ((p != src_end) && (*p == 'e' || *p == 'E')) { - p++; - if (p == src_end) { - return NUMBER_ERROR; - } - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while ((p != src_end) && parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if ((p != src_end) && jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, src_end, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - src += negative + 1; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. 
- // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. - overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 
0 - exp : exp; - } - - if (*p != '"') { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} -} // namespace {} -#endif // SIMDJSON_SKIPNUMBERPARSING - -} // namespace numberparsing -} // unnamed namespace -} // namespace haswell -} // namespace simdjson -/* end file include/simdjson/generic/numberparsing.h */ - -#endif // SIMDJSON_HASWELL_NUMBERPARSING_H -/* end file include/simdjson/haswell/numberparsing.h */ -/* begin file include/simdjson/haswell/end.h */ -SIMDJSON_UNTARGET_HASWELL -/* end file include/simdjson/haswell/end.h */ - -#endif // SIMDJSON_IMPLEMENTATION_HASWELL -#endif // SIMDJSON_HASWELL_COMMON_H -/* end file include/simdjson/haswell.h */ -/* begin file include/simdjson/ppc64.h */ -#ifndef SIMDJSON_PPC64_H -#define SIMDJSON_PPC64_H - - -#if SIMDJSON_IMPLEMENTATION_PPC64 - -namespace simdjson { -/** - * Implementation for ALTIVEC (PPC64). 
- */ -namespace ppc64 {} // namespace ppc64 -} // namespace simdjson - -/* begin file include/simdjson/ppc64/implementation.h */ -#ifndef SIMDJSON_PPC64_IMPLEMENTATION_H -#define SIMDJSON_PPC64_IMPLEMENTATION_H - - -namespace simdjson { -namespace ppc64 { - -namespace { -using namespace simdjson; -using namespace simdjson::dom; -} // namespace - -class implementation final : public simdjson::implementation { - public: - simdjson_really_inline implementation() - : simdjson::implementation( - "ppc64", "PPC64 ALTIVEC", internal::instruction_set::ALTIVEC) {} - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final; - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final; - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final; -}; - -} // namespace ppc64 -} // namespace simdjson - -#endif // SIMDJSON_PPC64_IMPLEMENTATION_H -/* end file include/simdjson/ppc64/implementation.h */ - -/* begin file include/simdjson/ppc64/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "ppc64" -// #define SIMDJSON_IMPLEMENTATION ppc64 -/* end file include/simdjson/ppc64/begin.h */ - -// Declarations -/* begin file include/simdjson/generic/dom_parser_implementation.h */ - -namespace simdjson { -namespace ppc64 { - -// expectation: sizeof(open_container) = 64/8. 
-struct open_container { - uint32_t tape_index; // where, on the tape, does the scope ([,{) begins - uint32_t count; // how many elements in the scope -}; // struct open_container - -static_assert(sizeof(open_container) == 64 / 8, - "Open container must be 64 bits"); - -class dom_parser_implementation final - : public internal::dom_parser_implementation { - public: - /** Tape location of each open { or [ */ - std::unique_ptr open_containers{}; - /** Whether each open container is a [ or { */ - std::unique_ptr is_array{}; - /** Buffer passed to stage 1 */ - const uint8_t *buf{}; - /** Length passed to stage 1 */ - size_t len{0}; - /** Document passed to stage 2 */ - dom::document *doc{}; - - inline dom_parser_implementation() noexcept; - inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - dom_parser_implementation(const dom_parser_implementation &) = delete; - dom_parser_implementation &operator=(const dom_parser_implementation &) = - delete; - - simdjson_warn_unused error_code parse(const uint8_t *buf, - size_t len, - dom::document &doc) noexcept final; - simdjson_warn_unused error_code stage1(const uint8_t *buf, - size_t len, - stage1_mode partial) noexcept final; - simdjson_warn_unused error_code stage2(dom::document &doc) noexcept final; - simdjson_warn_unused error_code - stage2_next(dom::document &doc) noexcept final; - inline simdjson_warn_unused error_code - set_capacity(size_t capacity) noexcept final; - inline simdjson_warn_unused error_code - set_max_depth(size_t max_depth) noexcept final; - - private: - simdjson_really_inline simdjson_warn_unused error_code - set_capacity_stage1(size_t capacity); -}; - -} // namespace ppc64 -} // namespace simdjson - -namespace simdjson { -namespace ppc64 { - -inline dom_parser_implementation::dom_parser_implementation() noexcept = - default; -inline 
dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -inline dom_parser_implementation &dom_parser_implementation::operator=( - dom_parser_implementation &&other) noexcept = default; - -// Leaving these here so they can be inlined if so desired -inline simdjson_warn_unused error_code -dom_parser_implementation::set_capacity(size_t capacity) noexcept { - if (capacity > SIMDJSON_MAXSIZE_BYTES) { - return CAPACITY; - } - // Stage 1 index output - size_t max_structures = SIMDJSON_ROUNDUP_N(capacity, 64) + 2 + 7; - structural_indexes.reset(new (std::nothrow) uint32_t[max_structures]); - if (!structural_indexes) { - _capacity = 0; - return MEMALLOC; - } - structural_indexes[0] = 0; - n_structural_indexes = 0; - - _capacity = capacity; - return SUCCESS; -} - -inline simdjson_warn_unused error_code -dom_parser_implementation::set_max_depth(size_t max_depth) noexcept { - // Stage 2 stacks - open_containers.reset(new (std::nothrow) open_container[max_depth]); - is_array.reset(new (std::nothrow) bool[max_depth]); - if (!is_array || !open_containers) { - _max_depth = 0; - return MEMALLOC; - } - - _max_depth = max_depth; - return SUCCESS; -} - -} // namespace ppc64 -} // namespace simdjson -/* end file include/simdjson/generic/dom_parser_implementation.h */ -/* begin file include/simdjson/ppc64/intrinsics.h */ -#ifndef SIMDJSON_PPC64_INTRINSICS_H -#define SIMDJSON_PPC64_INTRINSICS_H - - -// This should be the correct header whether -// you use visual studio or other compilers. -#include - -// These are defined by altivec.h in GCC toolchain, it is safe to undef them. 
-#ifdef bool -#undef bool -#endif - -#ifdef vector -#undef vector -#endif - -#endif // SIMDJSON_PPC64_INTRINSICS_H -/* end file include/simdjson/ppc64/intrinsics.h */ -/* begin file include/simdjson/ppc64/bitmanipulation.h */ -#ifndef SIMDJSON_PPC64_BITMANIPULATION_H -#define SIMDJSON_PPC64_BITMANIPULATION_H - -namespace simdjson { -namespace ppc64 { -namespace { - -// We sometimes call trailing_zero on inputs that are zero, -// but the algorithms do not end up using the returned value. -// Sadly, sanitizers are not smart enough to figure it out. -SIMDJSON_NO_SANITIZE_UNDEFINED -simdjson_really_inline int trailing_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long ret; - // Search the mask data from least significant bit (LSB) - // to the most significant bit (MSB) for a set bit (1). - _BitScanForward64(&ret, input_num); - return (int)ret; -#else // SIMDJSON_REGULAR_VISUAL_STUDIO - return __builtin_ctzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline uint64_t clear_lowest_bit(uint64_t input_num) { - return input_num & (input_num - 1); -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int leading_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long leading_zero = 0; - // Search the mask data from most significant bit (MSB) - // to least significant bit (LSB) for a set bit (1). 
- if (_BitScanReverse64(&leading_zero, input_num)) - return (int)(63 - leading_zero); - else - return 64; -#else - return __builtin_clzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO -simdjson_really_inline int count_ones(uint64_t input_num) { - // note: we do not support legacy 32-bit Windows - return __popcnt64(input_num); // Visual Studio wants two underscores -} -#else -simdjson_really_inline int count_ones(uint64_t input_num) { - return __builtin_popcountll(input_num); -} -#endif - -simdjson_really_inline bool add_overflow(uint64_t value1, - uint64_t value2, - uint64_t *result) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - *result = value1 + value2; - return *result < value1; -#else - return __builtin_uaddll_overflow( - value1, value2, reinterpret_cast(result)); -#endif -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -#endif // SIMDJSON_PPC64_BITMANIPULATION_H -/* end file include/simdjson/ppc64/bitmanipulation.h */ -/* begin file include/simdjson/ppc64/bitmask.h */ -#ifndef SIMDJSON_PPC64_BITMASK_H -#define SIMDJSON_PPC64_BITMASK_H - -namespace simdjson { -namespace ppc64 { -namespace { - -// -// Perform a "cumulative bitwise xor," flipping bits each time a 1 is -// encountered. -// -// For example, prefix_xor(00100100) == 00011100 -// -simdjson_really_inline uint64_t prefix_xor(uint64_t bitmask) { - // You can use the version below, however gcc sometimes miscompiles - // vec_pmsum_be, it happens somewhere around between 8 and 9th version. - // The performance boost was not noticeable, falling back to a usual - // implementation. - // __vector unsigned long long all_ones = {~0ull, ~0ull}; - // __vector unsigned long long mask = {bitmask, 0}; - // // Clang and GCC return different values for pmsum for ull so cast it - // to one. - // // Generally it is not specified by ALTIVEC ISA what is returned by - // // vec_pmsum_be. 
- // #if defined(__LITTLE_ENDIAN__) - // return (uint64_t)(((__vector unsigned long long)vec_pmsum_be(all_ones, - // mask))[0]); - // #else - // return (uint64_t)(((__vector unsigned long long)vec_pmsum_be(all_ones, - // mask))[1]); - // #endif - bitmask ^= bitmask << 1; - bitmask ^= bitmask << 2; - bitmask ^= bitmask << 4; - bitmask ^= bitmask << 8; - bitmask ^= bitmask << 16; - bitmask ^= bitmask << 32; - return bitmask; -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -#endif -/* end file include/simdjson/ppc64/bitmask.h */ -/* begin file include/simdjson/ppc64/simd.h */ -#ifndef SIMDJSON_PPC64_SIMD_H -#define SIMDJSON_PPC64_SIMD_H - -#include - -namespace simdjson { -namespace ppc64 { -namespace { -namespace simd { - -using __m128i = __vector unsigned char; - -template -struct base { - __m128i value; - - // Zero constructor - simdjson_really_inline base() : value{__m128i()} {} - - // Conversion from SIMD register - simdjson_really_inline base(const __m128i _value) : value(_value) {} - - // Conversion to SIMD register - simdjson_really_inline operator const __m128i &() const { - return this->value; - } - simdjson_really_inline operator __m128i &() { return this->value; } - - // Bit operations - simdjson_really_inline Child operator|(const Child other) const { - return vec_or(this->value, (__m128i)other); - } - simdjson_really_inline Child operator&(const Child other) const { - return vec_and(this->value, (__m128i)other); - } - simdjson_really_inline Child operator^(const Child other) const { - return vec_xor(this->value, (__m128i)other); - } - simdjson_really_inline Child bit_andnot(const Child other) const { - return vec_andc(this->value, (__m128i)other); - } - simdjson_really_inline Child &operator|=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast | other; - return *this_cast; - } - simdjson_really_inline Child &operator&=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = 
*this_cast & other; - return *this_cast; - } - simdjson_really_inline Child &operator^=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast ^ other; - return *this_cast; - } -}; - -// Forward-declared so they can be used by splat and friends. -template -struct simd8; - -template > -struct base8 : base> { - typedef uint16_t bitmask_t; - typedef uint32_t bitmask2_t; - - simdjson_really_inline base8() : base>() {} - simdjson_really_inline base8(const __m128i _value) - : base>(_value) {} - - friend simdjson_really_inline Mask operator==(const simd8 lhs, - const simd8 rhs) { - return (__m128i)vec_cmpeq(lhs.value, (__m128i)rhs); - } - - static const int SIZE = sizeof(base>::value); - - template - simdjson_really_inline simd8 prev(simd8 prev_chunk) const { - __m128i chunk = this->value; -#ifdef __LITTLE_ENDIAN__ - chunk = (__m128i)vec_reve(this->value); - prev_chunk = (__m128i)vec_reve((__m128i)prev_chunk); -#endif - chunk = (__m128i)vec_sld((__m128i)prev_chunk, (__m128i)chunk, 16 - N); -#ifdef __LITTLE_ENDIAN__ - chunk = (__m128i)vec_reve((__m128i)chunk); -#endif - return chunk; - } -}; - -// SIMD byte mask type (returned by things like eq and gt) -template <> -struct simd8 : base8 { - static simdjson_really_inline simd8 splat(bool _value) { - return (__m128i)vec_splats((unsigned char)(-(!!_value))); - } - - simdjson_really_inline simd8() : base8() {} - simdjson_really_inline simd8(const __m128i _value) - : base8(_value) {} - // Splat constructor - simdjson_really_inline simd8(bool _value) - : base8(splat(_value)) {} - - simdjson_really_inline int to_bitmask() const { - __vector unsigned long long result; - const __m128i perm_mask = {0x78, - 0x70, - 0x68, - 0x60, - 0x58, - 0x50, - 0x48, - 0x40, - 0x38, - 0x30, - 0x28, - 0x20, - 0x18, - 0x10, - 0x08, - 0x00}; - - result = ((__vector unsigned long long)vec_vbpermq( - (__m128i) this->value, (__m128i)perm_mask)); -#ifdef __LITTLE_ENDIAN__ - return static_cast(result[1]); -#else - return 
static_cast(result[0]); -#endif - } - simdjson_really_inline bool any() const { - return !vec_all_eq(this->value, (__m128i)vec_splats(0)); - } - simdjson_really_inline simd8 operator~() const { - return this->value ^ (__m128i)splat(true); - } -}; - -template -struct base8_numeric : base8 { - static simdjson_really_inline simd8 splat(T value) { - (void)value; - return (__m128i)vec_splats(value); - } - static simdjson_really_inline simd8 zero() { return splat(0); } - static simdjson_really_inline simd8 load(const T values[16]) { - return (__m128i)( - vec_vsx_ld(0, reinterpret_cast(values))); - } - // Repeat 16 values as many times as necessary (usually for lookup tables) - static simdjson_really_inline simd8 repeat_16(T v0, - T v1, - T v2, - T v3, - T v4, - T v5, - T v6, - T v7, - T v8, - T v9, - T v10, - T v11, - T v12, - T v13, - T v14, - T v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - simdjson_really_inline base8_numeric() : base8() {} - simdjson_really_inline base8_numeric(const __m128i _value) - : base8(_value) {} - - // Store to array - simdjson_really_inline void store(T dst[16]) const { - vec_vsx_st(this->value, 0, reinterpret_cast<__m128i *>(dst)); - } - - // Override to distinguish from bool version - simdjson_really_inline simd8 operator~() const { return *this ^ 0xFFu; } - - // Addition/subtraction are the same for signed and unsigned - simdjson_really_inline simd8 operator+(const simd8 other) const { - return (__m128i)((__m128i) this->value + (__m128i)other); - } - simdjson_really_inline simd8 operator-(const simd8 other) const { - return (__m128i)((__m128i) this->value - (__m128i)other); - } - simdjson_really_inline simd8 &operator+=(const simd8 other) { - *this = *this + other; - return *static_cast *>(this); - } - simdjson_really_inline simd8 &operator-=(const simd8 other) { - *this = *this - other; - return *static_cast *>(this); - } - - // Perform a lookup assuming 
the value is between 0 and 16 (undefined - // behavior - // for out of range values) - template - simdjson_really_inline simd8 lookup_16(simd8 lookup_table) const { - return (__m128i)vec_perm( - (__m128i)lookup_table, (__m128i)lookup_table, this->value); - } - - // Copies to 'output" all bytes corresponding to a 0 in the mask - // (interpreted - // as a bitset). Passing a 0 value for mask would be equivalent to writing - // out - // every byte to output. Only the first 16 - count_ones(mask) bytes of the - // result are significant but 16 bytes get written. Design consideration: it - // seems like a function with the signature simd8 compress(uint32_t mask) - // would be sensible, but the AVX ISA makes this kind of approach difficult. - template - simdjson_really_inline void compress(uint16_t mask, L *output) const { - using internal::BitsSetTable256mul2; - using internal::pshufb_combine_table; - using internal::thintable_epi8; - // this particular implementation was inspired by work done by - // @animetosho - // we do it in two steps, first 8 bytes and then second 8 bytes - uint8_t mask1 = uint8_t(mask); // least significant 8 bits - uint8_t mask2 = uint8_t(mask >> 8); // most significant 8 bits -// next line just loads the 64-bit values thintable_epi8[mask1] and -// thintable_epi8[mask2] into a 128-bit register, using only -// two instructions on most compilers. 
-#ifdef __LITTLE_ENDIAN__ - __m128i shufmask = (__m128i)(__vector unsigned long long){ - thintable_epi8[mask1], thintable_epi8[mask2]}; -#else - __m128i shufmask = (__m128i)(__vector unsigned long long){ - thintable_epi8[mask2], thintable_epi8[mask1]}; - shufmask = (__m128i)vec_reve((__m128i)shufmask); -#endif - // we increment by 0x08 the second half of the mask - shufmask = ((__m128i)shufmask) + - ((__m128i)(__vector int){0, 0, 0x08080808, 0x08080808}); - - // this is the version "nearly pruned" - __m128i pruned = vec_perm(this->value, this->value, shufmask); - // we still need to put the two halves together. - // we compute the popcount of the first half: - int pop1 = BitsSetTable256mul2[mask1]; - // then load the corresponding mask, what it does is to write - // only the first pop1 bytes from the first 8 bytes, and then - // it fills in with the bytes from the second 8 bytes + some filling - // at the end. - __m128i compactmask = vec_vsx_ld( - 0, - reinterpret_cast(pshufb_combine_table + pop1 * 8)); - __m128i answer = vec_perm(pruned, (__m128i)vec_splats(0), compactmask); - vec_vsx_st(answer, 0, reinterpret_cast<__m128i *>(output)); - } - - template - simdjson_really_inline simd8 lookup_16(L replace0, - L replace1, - L replace2, - L replace3, - L replace4, - L replace5, - L replace6, - L replace7, - L replace8, - L replace9, - L replace10, - L replace11, - L replace12, - L replace13, - L replace14, - L replace15) const { - return lookup_16(simd8::repeat_16(replace0, - replace1, - replace2, - replace3, - replace4, - replace5, - replace6, - replace7, - replace8, - replace9, - replace10, - replace11, - replace12, - replace13, - replace14, - replace15)); - } -}; - -// Signed bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m128i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(int8_t _value) : simd8(splat(_value)) {} - // Array 
constructor - simdjson_really_inline simd8(const int8_t *values) : simd8(load(values)) {} - // Member-by-member initialization - simdjson_really_inline simd8(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) - : simd8((__m128i)(__vector signed char){v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15}) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Order-sensitive comparisons - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return (__m128i)vec_max((__vector signed char)this->value, - (__vector signed char)(__m128i)other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return (__m128i)vec_min((__vector signed char)this->value, - (__vector signed char)(__m128i)other); - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return (__m128i)vec_cmpgt((__vector signed char)this->value, - (__vector signed char)(__m128i)other); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return (__m128i)vec_cmplt((__vector signed char)this->value, - (__vector signed char)(__m128i)other); - } -}; - -// Unsigned bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m128i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(uint8_t _value) : 
simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const uint8_t *values) : simd8(load(values)) {} - // Member-by-member initialization - simdjson_really_inline simd8(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) - : simd8((__m128i){v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15}) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Saturated math - simdjson_really_inline simd8 saturating_add( - const simd8 other) const { - return (__m128i)vec_adds(this->value, (__m128i)other); - } - simdjson_really_inline simd8 saturating_sub( - const simd8 other) const { - return (__m128i)vec_subs(this->value, (__m128i)other); - } - - // Order-specific operations - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return (__m128i)vec_max(this->value, (__m128i)other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return (__m128i)vec_min(this->value, (__m128i)other); - } - // Same as >, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 gt_bits( - const simd8 other) const { - return this->saturating_sub(other); - } - // Same as <, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 lt_bits( - const simd8 other) const { - return other.saturating_sub(*this); - } - 
simdjson_really_inline simd8 operator<=( - const simd8 other) const { - return other.max_val(*this) == other; - } - simdjson_really_inline simd8 operator>=( - const simd8 other) const { - return other.min_val(*this) == other; - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return this->gt_bits(other).any_bits_set(); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return this->gt_bits(other).any_bits_set(); - } - - // Bit-specific operations - simdjson_really_inline simd8 bits_not_set() const { - return (__m128i)vec_cmpeq(this->value, (__m128i)vec_splats(uint8_t(0))); - } - simdjson_really_inline simd8 bits_not_set(simd8 bits) const { - return (*this & bits).bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set() const { - return ~this->bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set(simd8 bits) const { - return ~this->bits_not_set(bits); - } - simdjson_really_inline bool bits_not_set_anywhere() const { - return vec_all_eq(this->value, (__m128i)vec_splats(0)); - } - simdjson_really_inline bool any_bits_set_anywhere() const { - return !bits_not_set_anywhere(); - } - simdjson_really_inline bool bits_not_set_anywhere( - simd8 bits) const { - return vec_all_eq(vec_and(this->value, (__m128i)bits), - (__m128i)vec_splats(0)); - } - simdjson_really_inline bool any_bits_set_anywhere( - simd8 bits) const { - return !bits_not_set_anywhere(bits); - } - template - simdjson_really_inline simd8 shr() const { - return simd8( - (__m128i)vec_sr(this->value, (__m128i)vec_splat_u8(N))); - } - template - simdjson_really_inline simd8 shl() const { - return simd8( - (__m128i)vec_sl(this->value, (__m128i)vec_splat_u8(N))); - } -}; - -template -struct simd8x64 { - static constexpr int NUM_CHUNKS = 64 / sizeof(simd8); - static_assert(NUM_CHUNKS == 4, - "PPC64 kernel should use four registers per 64-byte block."); - const simd8 chunks[NUM_CHUNKS]; - - simd8x64(const simd8x64 &o) = delete; // no copy allowed - 
simd8x64 &operator=(const simd8 &other) = - delete; // no assignment allowed - simd8x64() = delete; // no default constructor allowed - - simdjson_really_inline simd8x64(const simd8 chunk0, - const simd8 chunk1, - const simd8 chunk2, - const simd8 chunk3) - : chunks{chunk0, chunk1, chunk2, chunk3} {} - simdjson_really_inline simd8x64(const T ptr[64]) - : chunks{simd8::load(ptr), - simd8::load(ptr + 16), - simd8::load(ptr + 32), - simd8::load(ptr + 48)} {} - - simdjson_really_inline void store(T ptr[64]) const { - this->chunks[0].store(ptr + sizeof(simd8) * 0); - this->chunks[1].store(ptr + sizeof(simd8) * 1); - this->chunks[2].store(ptr + sizeof(simd8) * 2); - this->chunks[3].store(ptr + sizeof(simd8) * 3); - } - - simdjson_really_inline simd8 reduce_or() const { - return (this->chunks[0] | this->chunks[1]) | - (this->chunks[2] | this->chunks[3]); - } - - simdjson_really_inline uint64_t compress(uint64_t mask, T *output) const { - this->chunks[0].compress(uint16_t(mask), output); - this->chunks[1].compress(uint16_t(mask >> 16), - output + 16 - count_ones(mask & 0xFFFF)); - this->chunks[2].compress(uint16_t(mask >> 32), - output + 32 - count_ones(mask & 0xFFFFFFFF)); - this->chunks[3].compress( - uint16_t(mask >> 48), - output + 48 - count_ones(mask & 0xFFFFFFFFFFFF)); - return 64 - count_ones(mask); - } - - simdjson_really_inline uint64_t to_bitmask() const { - uint64_t r0 = uint32_t(this->chunks[0].to_bitmask()); - uint64_t r1 = this->chunks[1].to_bitmask(); - uint64_t r2 = this->chunks[2].to_bitmask(); - uint64_t r3 = this->chunks[3].to_bitmask(); - return r0 | (r1 << 16) | (r2 << 32) | (r3 << 48); - } - - simdjson_really_inline uint64_t eq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] == mask, - this->chunks[1] == mask, - this->chunks[2] == mask, - this->chunks[3] == mask) - .to_bitmask(); - } - - simdjson_really_inline uint64_t eq(const simd8x64 &other) const { - return simd8x64(this->chunks[0] == other.chunks[0], - 
this->chunks[1] == other.chunks[1], - this->chunks[2] == other.chunks[2], - this->chunks[3] == other.chunks[3]) - .to_bitmask(); - } - - simdjson_really_inline uint64_t lteq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] <= mask, - this->chunks[1] <= mask, - this->chunks[2] <= mask, - this->chunks[3] <= mask) - .to_bitmask(); - } -}; // struct simd8x64 - -} // namespace simd -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -#endif // SIMDJSON_PPC64_SIMD_INPUT_H -/* end file include/simdjson/ppc64/simd.h */ -/* begin file include/simdjson/generic/jsoncharutils.h */ - -namespace simdjson { -namespace ppc64 { -namespace { -namespace jsoncharutils { - -// return non-zero if not a structural or whitespace char -// zero otherwise -simdjson_really_inline uint32_t is_not_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace_negated[c]; -} - -simdjson_really_inline uint32_t is_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace[c]; -} - -// returns a value with the high 16 bits set if not valid -// otherwise returns the conversion of the 4 hex digits at src into the bottom -// 16 bits of the 32-bit return register -// -// see -// https://lemire.me/blog/2019/04/17/parsing-short-hexadecimal-strings-efficiently/ -static inline uint32_t hex_to_u32_nocheck( - const uint8_t *src) { // strictly speaking, static inline is a C-ism - uint32_t v1 = internal::digit_to_val32[630 + src[0]]; - uint32_t v2 = internal::digit_to_val32[420 + src[1]]; - uint32_t v3 = internal::digit_to_val32[210 + src[2]]; - uint32_t v4 = internal::digit_to_val32[0 + src[3]]; - return v1 | v2 | v3 | v4; -} - -// given a code point cp, writes to c -// the utf-8 code, outputting the length in -// bytes, if the length is zero, the code point -// is invalid -// -// This can possibly be made faster using pdep -// and clz and table lookups, but JSON documents -// have few escaped code points, 
and the following -// function looks cheap. -// -// Note: we assume that surrogates are treated separately -// -simdjson_really_inline size_t codepoint_to_utf8(uint32_t cp, uint8_t *c) { - if (cp <= 0x7F) { - c[0] = uint8_t(cp); - return 1; // ascii - } - if (cp <= 0x7FF) { - c[0] = uint8_t((cp >> 6) + 192); - c[1] = uint8_t((cp & 63) + 128); - return 2; // universal plane - // Surrogates are treated elsewhere... - //} //else if (0xd800 <= cp && cp <= 0xdfff) { - // return 0; // surrogates // could put assert here - } else if (cp <= 0xFFFF) { - c[0] = uint8_t((cp >> 12) + 224); - c[1] = uint8_t(((cp >> 6) & 63) + 128); - c[2] = uint8_t((cp & 63) + 128); - return 3; - } else if (cp <= - 0x10FFFF) { // if you know you have a valid code point, this - // is not needed - c[0] = uint8_t((cp >> 18) + 240); - c[1] = uint8_t(((cp >> 12) & 63) + 128); - c[2] = uint8_t(((cp >> 6) & 63) + 128); - c[3] = uint8_t((cp & 63) + 128); - return 4; - } - // will return 0 when the code point was too large. - return 0; // bad r -} - -#ifdef SIMDJSON_IS_32BITS // _umul128 for x86, arm -// this is a slow emulation routine for 32-bit -// -static simdjson_really_inline uint64_t __emulu(uint32_t x, uint32_t y) { - return x * (uint64_t)y; -} -static simdjson_really_inline uint64_t _umul128(uint64_t ab, - uint64_t cd, - uint64_t *hi) { - uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd); - uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd); - uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32)); - uint64_t adbc_carry = !!(adbc < ad); - uint64_t lo = bd + (adbc << 32); - *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + - (adbc_carry << 32) + !!(lo < bd); - return lo; -} -#endif - -using internal::value128; - -simdjson_really_inline value128 full_multiplication(uint64_t value1, - uint64_t value2) { - value128 answer; -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) -#ifdef _M_ARM64 - // ARM64 has native support for 64-bit 
multiplications, no need to emultate - answer.high = __umulh(value1, value2); - answer.low = value1 * value2; -#else - answer.low = _umul128( - value1, value2, &answer.high); // _umul128 not available on ARM64 -#endif // _M_ARM64 -#else // defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) - __uint128_t r = (static_cast<__uint128_t>(value1)) * value2; - answer.low = uint64_t(r); - answer.high = uint64_t(r >> 64); -#endif - return answer; -} - -} // namespace jsoncharutils -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file include/simdjson/generic/jsoncharutils.h */ -/* begin file include/simdjson/generic/atomparsing.h */ -namespace simdjson { -namespace ppc64 { -namespace { -/// @private -namespace atomparsing { - -// The string_to_uint32 is exclusively used to map literal strings to 32-bit -// values. -// We use memcpy instead of a pointer cast to avoid undefined behaviors since we -// cannot -// be certain that the character pointer will be properly aligned. -// You might think that using memcpy makes this function expensive, but you'd be -// wrong. -// All decent optimizing compilers (GCC, clang, Visual Studio) will compile -// string_to_uint32("false"); -// to the compile-time constant 1936482662. -simdjson_really_inline uint32_t string_to_uint32(const char *str) { - uint32_t val; - std::memcpy(&val, str, sizeof(uint32_t)); - return val; -} - - -// Again in str4ncmp we use a memcpy to avoid undefined behavior. The memcpy may -// appear expensive. -// Yet all decent optimizing compilers will compile memcpy to a single -// instruction, just about. 
-simdjson_warn_unused simdjson_really_inline uint32_t -str4ncmp(const uint8_t *src, const char *atom) { - uint32_t - srcval; // we want to avoid unaligned 32-bit loads (undefined in C/C++) - static_assert(sizeof(uint32_t) <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be larger than 4 bytes"); - std::memcpy(&srcval, src, sizeof(uint32_t)); - return srcval ^ string_to_uint32(atom); -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src) { - return (str4ncmp(src, "true") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_true_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "true"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src) { - return (str4ncmp(src + 1, "alse") | - jsoncharutils::is_not_structural_or_whitespace(src[5])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src, size_t len) { - if (len > 5) { - return is_valid_false_atom(src); - } else if (len == 5) { - return !str4ncmp(src + 1, "alse"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src) { - return (str4ncmp(src, "null") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_null_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "null"); - } else { - return false; - } -} - -} // namespace atomparsing -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file include/simdjson/generic/atomparsing.h */ -/* begin file include/simdjson/ppc64/stringparsing.h */ -#ifndef SIMDJSON_PPC64_STRINGPARSING_H 
-#define SIMDJSON_PPC64_STRINGPARSING_H - - -namespace simdjson { -namespace ppc64 { -namespace { - -using namespace simd; - -// Holds backslashes and quotes locations. -struct backslash_and_quote { - public: - static constexpr uint32_t BYTES_PROCESSED = 32; - simdjson_really_inline static backslash_and_quote copy_and_find( - const uint8_t *src, uint8_t *dst); - - simdjson_really_inline bool has_quote_first() { - return ((bs_bits - 1) & quote_bits) != 0; - } - simdjson_really_inline bool has_backslash() { return bs_bits != 0; } - simdjson_really_inline int quote_index() { - return trailing_zeroes(quote_bits); - } - simdjson_really_inline int backslash_index() { - return trailing_zeroes(bs_bits); - } - - uint32_t bs_bits; - uint32_t quote_bits; -}; // struct backslash_and_quote - -simdjson_really_inline backslash_and_quote -backslash_and_quote::copy_and_find(const uint8_t *src, uint8_t *dst) { - // this can read up to 31 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(SIMDJSON_PADDING >= (BYTES_PROCESSED - 1), - "backslash and quote finder must process fewer than " - "SIMDJSON_PADDING bytes"); - simd8 v0(src); - simd8 v1(src + sizeof(v0)); - v0.store(dst); - v1.store(dst + sizeof(v0)); - - // Getting a 64-bit bitmask is much cheaper than multiple 16-bit bitmasks on - // PPC; therefore, we smash them together into a 64-byte mask and get the - // bitmask from there. 
- uint64_t bs_and_quote = - simd8x64(v0 == '\\', v1 == '\\', v0 == '"', v1 == '"') - .to_bitmask(); - return { - uint32_t(bs_and_quote), // bs_bits - uint32_t(bs_and_quote >> 32) // quote_bits - }; -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -/* begin file include/simdjson/generic/stringparsing.h */ -// This file contains the common code every implementation uses -// It is intended to be included multiple times and compiled multiple times - -namespace simdjson { -namespace ppc64 { -namespace { -/// @private -namespace stringparsing { - -// begin copypasta -// These chars yield themselves: " \ / -// b -> backspace, f -> formfeed, n -> newline, r -> cr, t -> horizontal tab -// u not handled in this table as it's complex -static const uint8_t escape_map[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x4. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x5c, 0, 0, 0, // 0x5. - 0, 0, 0x08, 0, 0, 0, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0x0a, 0, // 0x6. - 0, 0, 0x0d, 0, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x7. 
- - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -// handle a unicode codepoint -// write appropriate values into dest -// src will advance 6 bytes or 12 bytes -// dest will advance a variable amount (return via pointer) -// return true if the unicode codepoint was valid -// We work in little-endian then swap at write time -simdjson_warn_unused simdjson_really_inline bool handle_unicode_codepoint( - const uint8_t **src_ptr, uint8_t **dst_ptr) { - // jsoncharutils::hex_to_u32_nocheck fills high 16 bits of the return value - // with 1s if the - // conversion isn't valid; we defer the check for this to inside the - // multilingual plane check - uint32_t code_point = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - *src_ptr += 6; - // check for low surrogate for characters outside the Basic - // Multilingual Plane. - if (code_point >= 0xd800 && code_point < 0xdc00) { - if (((*src_ptr)[0] != '\\') || (*src_ptr)[1] != 'u') { - return false; - } - uint32_t code_point_2 = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - - // if the first code point is invalid we will get here, as we will go - // past - // the check for being outside the Basic Multilingual plane. If we don't - // find a \u immediately afterwards we fail out anyhow, but if we do, - // this check catches both the case of the first code point being - // invalid - // or the second code point being invalid. 
- if ((code_point | code_point_2) >> 16) { - return false; - } - - code_point = - (((code_point - 0xd800) << 10) | (code_point_2 - 0xdc00)) + 0x10000; - *src_ptr += 6; - } - size_t offset = jsoncharutils::codepoint_to_utf8(code_point, *dst_ptr); - *dst_ptr += offset; - return offset > 0; -} - -/** - * Unescape a string from src to dst, stopping at a final unescaped quote. E.g., - * if src points at 'joe"', then - * dst needs to have four free bytes. - */ -simdjson_warn_unused simdjson_really_inline uint8_t *parse_string( - const uint8_t *src, uint8_t *dst) { - while (1) { - // Copy the next n bytes, and find the backslash and quote in them. - auto bs_quote = backslash_and_quote::copy_and_find(src, dst); - // If the next thing is the end quote, copy and return - if (bs_quote.has_quote_first()) { - // we encountered quotes first. Move dst to point to quotes and exit - return dst + bs_quote.quote_index(); - } - if (bs_quote.has_backslash()) { - /* find out where the backspace is */ - auto bs_dist = bs_quote.backslash_index(); - uint8_t escape_char = src[bs_dist + 1]; - /* we encountered backslash first. Handle backslash */ - if (escape_char == 'u') { - /* move src/dst up to the start; they will be further adjusted - within the unicode codepoint handling code. */ - src += bs_dist; - dst += bs_dist; - if (!handle_unicode_codepoint(&src, &dst)) { - return nullptr; - } - } else { - /* simple 1:1 conversion. Will eat bs_dist+2 characters in input - * and - * write bs_dist+1 characters to output - * note this may reach beyond the part of the buffer we've - * actually - * seen. I think this is ok */ - uint8_t escape_result = escape_map[escape_char]; - if (escape_result == 0u) { - return nullptr; /* bogus escape value is an error */ - } - dst[bs_dist] = escape_result; - src += bs_dist + 2; - dst += bs_dist + 1; - } - } else { - /* they are the same. Since they can't co-occur, it means we - * encountered neither. 
*/ - src += backslash_and_quote::BYTES_PROCESSED; - dst += backslash_and_quote::BYTES_PROCESSED; - } - } - /* can't be reached */ - return nullptr; -} - -simdjson_unused simdjson_warn_unused simdjson_really_inline error_code -parse_string_to_buffer(const uint8_t *src, - uint8_t *¤t_string_buf_loc, - std::string_view &s) { - if (*(src++) != '"') { - return STRING_ERROR; - } - auto end = stringparsing::parse_string(src, current_string_buf_loc); - if (!end) { - return STRING_ERROR; - } - s = std::string_view(reinterpret_cast(current_string_buf_loc), - end - current_string_buf_loc); - current_string_buf_loc = end; - return SUCCESS; -} - -} // namespace stringparsing -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file include/simdjson/generic/stringparsing.h */ - -#endif // SIMDJSON_PPC64_STRINGPARSING_H -/* end file include/simdjson/ppc64/stringparsing.h */ -/* begin file include/simdjson/ppc64/numberparsing.h */ -#ifndef SIMDJSON_PPC64_NUMBERPARSING_H -#define SIMDJSON_PPC64_NUMBERPARSING_H - -#if defined(__linux__) -#include -#elif defined(__FreeBSD__) -#include -#endif - -namespace simdjson { -namespace ppc64 { -namespace { - -// we don't have appropriate instructions, so let us use a scalar function -// credit: https://johnnylee-sde.github.io/Fast-numeric-string-to-int/ -static simdjson_really_inline uint32_t -parse_eight_digits_unrolled(const uint8_t *chars) { - uint64_t val; - std::memcpy(&val, chars, sizeof(uint64_t)); -#ifdef __BIG_ENDIAN__ -#if defined(__linux__) - val = bswap_64(val); -#elif defined(__FreeBSD__) - val = bswap64(val); -#endif -#endif - val = (val & 0x0F0F0F0F0F0F0F0F) * 2561 >> 8; - val = (val & 0x00FF00FF00FF00FF) * 6553601 >> 16; - return uint32_t((val & 0x0000FFFF0000FFFF) * 42949672960001 >> 32); -} - -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson - -#define SIMDJSON_SWAR_NUMBER_PARSING 1 - -/* begin file include/simdjson/generic/numberparsing.h */ -#include - -namespace simdjson { 
-namespace ppc64 { - -namespace ondemand { -/** - * The type of a JSON number - */ -enum class number_type { - floating_point_number = 1, /// a binary64 number - signed_integer, /// a signed integer that fits in a 64-bit word using two's - /// complement - unsigned_integer /// a positive integer larger or equal to 1<<63 -}; -} - -namespace { -/// @private -namespace numberparsing { - - -#ifdef JSON_TEST_NUMBERS -#define INVALID_NUMBER(SRC) (found_invalid_number((SRC)), NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) \ - (found_integer((VALUE), (SRC)), (WRITER).append_s64((VALUE))) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) \ - (found_unsigned_integer((VALUE), (SRC)), (WRITER).append_u64((VALUE))) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) \ - (found_float((VALUE), (SRC)), (WRITER).append_double((VALUE))) -#else -#define INVALID_NUMBER(SRC) (NUMBER_ERROR) -#define WRITE_INTEGER(VALUE, SRC, WRITER) (WRITER).append_s64((VALUE)) -#define WRITE_UNSIGNED(VALUE, SRC, WRITER) (WRITER).append_u64((VALUE)) -#define WRITE_DOUBLE(VALUE, SRC, WRITER) (WRITER).append_double((VALUE)) -#endif - -namespace { -// Convert a mantissa, an exponent and a sign bit into an ieee64 double. -// The real_exponent needs to be in [0, 2046] (technically real_exponent = 2047 -// would be acceptable). -// The mantissa should be in [0,1<<53). The bit at index (1ULL << 52) while be -// zeroed. -simdjson_really_inline double to_double(uint64_t mantissa, - uint64_t real_exponent, - bool negative) { - double d; - mantissa &= ~(1ULL << 52); - mantissa |= real_exponent << 52; - mantissa |= ((static_cast(negative)) << 63); - std::memcpy(&d, &mantissa, sizeof(d)); - return d; -} -} -// Attempts to compute i * 10^(power) exactly; and if "negative" is -// true, negate the result. -// This function will only work in some cases, when it does not work, success is -// set to false. This should work *most of the time* (like 99% of the time). 
-// We assume that power is in the [smallest_power, -// largest_power] interval: the caller is responsible for this check. -simdjson_really_inline bool compute_float_64(int64_t power, - uint64_t i, - bool negative, - double &d) { -// we start with a fast path -// It was described in -// Clinger WD. How to read floating point numbers accurately. -// ACM SIGPLAN Notices. 1990 -#ifndef FLT_EVAL_METHOD -#error "FLT_EVAL_METHOD should be defined, please include cfloat." -#endif -#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) - // We cannot be certain that x/y is rounded to nearest. - if (0 <= power && power <= 22 && i <= 9007199254740991) { -#else - if (-22 <= power && power <= 22 && i <= 9007199254740991) { -#endif - // convert the integer into a double. This is lossless since - // 0 <= i <= 2^53 - 1. - d = double(i); - // - // The general idea is as follows. - // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then - // 1) Both s and p can be represented exactly as 64-bit floating-point - // values - // (binary64). - // 2) Because s and p can be represented exactly as floating-point - // values, - // then s * p - // and s / p will produce correctly rounded values. - // - if (power < 0) { - d = d / simdjson::internal::power_of_ten[-power]; - } else { - d = d * simdjson::internal::power_of_ten[power]; - } - if (negative) { - d = -d; - } - return true; - } - // When 22 < power && power < 22 + 16, we could - // hope for another, secondary fast path. It was - // described by David M. Gay in "Correctly rounded - // binary-decimal and decimal-binary conversions." (1990) - // If you need to compute i * 10^(22 + x) for x < 16, - // first compute i * 10^x, if you know that result is exact - // (e.g., when i * 10^x < 2^53), - // then you can still proceed and do (i * 10^x) * 10^22. - // Is this worth your time? - // You need 22 < power *and* power < 22 + 16 *and* (i * 10^(x-22) < 2^53) - // for this second fast path to work. 
- // If you you have 22 < power *and* power < 22 + 16, and then you - // optimistically compute "i * 10^(x-22)", there is still a chance that you - // have wasted your time if i * 10^(x-22) >= 2^53. It makes the use cases of - // this optimization maybe less common than we would like. Source: - // http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ - // also used in RapidJSON: https://rapidjson.org/strtod_8h_source.html - - // The fast path has now failed, so we are failing back on the slower path. - - // In the slow path, we need to adjust i so that it is > 1<<63 which is - // always - // possible, except if i == 0, so we handle i == 0 separately. - if (i == 0) { - d = 0.0; - return true; - } - - - // The exponent is 1024 + 63 + power - // + floor(log(5**power)/log(2)). - // The 1024 comes from the ieee64 standard. - // The 63 comes from the fact that we use a 64-bit word. - // - // Computing floor(log(5**power)/log(2)) could be - // slow. Instead we use a fast function. - // - // For power in (-400,350), we have that - // (((152170 + 65536) * power ) >> 16); - // is equal to - // floor(log(5**power)/log(2)) + power when power >= 0 - // and it is equal to - // ceil(log(5**-power)/log(2)) + power when power < 0 - // - // The 65536 is (1<<16) and corresponds to - // (65536 * power) >> 16 ---> power - // - // ((152170 * power ) >> 16) is equal to - // floor(log(5**power)/log(2)) - // - // Note that this is not magic: 152170/(1<<16) is - // approximatively equal to log(5)/log(2). - // The 1<<16 value is a power of two; we could use a - // larger power of 2 if we wanted to. - // - int64_t exponent = (((152170 + 65536) * power) >> 16) + 1024 + 63; - - - // We want the most significant bit of i to be 1. Shift if needed. - int lz = leading_zeroes(i); - i <<= lz; - - - // We are going to need to do some 64-bit arithmetic to get a precise - // product. - // We use a table lookup approach. 
- // It is safe because - // power >= smallest_power - // and power <= largest_power - // We recover the mantissa of the power, it has a leading 1. It is always - // rounded down. - // - // We want the most significant 64 bits of the product. We know - // this will be non-zero because the most significant bit of i is - // 1. - const uint32_t index = - 2 * uint32_t(power - simdjson::internal::smallest_power); - // Optimization: It may be that materializing the index as a variable might - // confuse some compilers and prevent effective complex-addressing loads. - // (Done for code clarity.) - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high component" - // corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 firstproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index]); - // Both i and power_of_five_128[index] have their most significant bit set - // to 1 which - // implies that the either the most or the second most significant bit of - // the product - // is 1. We pack values in this manner for efficiency reasons: it maximizes - // the use - // we make of the product. It also makes it easy to reason about the - // product: there - // is 0 or 1 leading zero in the product. - - // Unless the least significant 9 bits of the high (64-bit) part of the full - // product are all 1s, then we know that the most significant 55 bits are - // exact and no further work is needed. Having 55 bits is necessary because - // we need 53 bits for the mantissa but we have to have one rounding bit and - // we can waste a bit if the most significant bit of the product is zero. 
- if ((firstproduct.high & 0x1FF) == 0x1FF) { - // We want to compute i * 5^q, but only care about the top 55 bits at - // most. - // Consider the scenario where q>=0. Then 5^q may not fit in 64-bits. - // Doing - // the full computation is wasteful. So we do what is called a - // "truncated - // multiplication". - // We take the most significant 64-bits, and we put them in - // power_of_five_128[index]. Usually, that's good enough to approximate - // i * 5^q - // to the desired approximation using one multiplication. Sometimes it - // does not suffice. - // Then we store the next most significant 64 bits in - // power_of_five_128[index + 1], and - // then we get a better approximation to i * 5^q. In very rare cases, - // even that - // will not suffice, though it is seemingly very hard to find such a - // scenario. - // - // That's for when q>=0. The logic for q<0 is somewhat similar but it is - // somewhat - // more complicated. - // - // There is an extra layer of complexity in that we need more than 55 - // bits of - // accuracy in the round-to-even scenario. - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high - // component" corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 secondproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index + 1]); - firstproduct.low += secondproduct.high; - if (secondproduct.high > firstproduct.low) { - firstproduct.high++; - } - // At this point, we might need to add at most one to firstproduct, but - // this - // can only change the value of firstproduct.high if firstproduct.low is - // maximal. - if (simdjson_unlikely(firstproduct.low == 0xFFFFFFFFFFFFFFFF)) { - // This is very unlikely, but if so, we need to do much more work! 
- return false; - } - } - uint64_t lower = firstproduct.low; - uint64_t upper = firstproduct.high; - // The final mantissa should be 53 bits with a leading 1. - // We shift it so that it occupies 54 bits with a leading 1. - /////// - uint64_t upperbit = upper >> 63; - uint64_t mantissa = upper >> (upperbit + 9); - lz += int(1 ^ upperbit); - - // Here we have mantissa < (1<<54). - int64_t real_exponent = exponent - lz; - if (simdjson_unlikely(real_exponent <= 0)) { // we have a subnormal? - // Here have that real_exponent <= 0 so -real_exponent >= 0 - if (-real_exponent + 1 >= 64) { // if we have more than 64 bits below - // the minimum exponent, you have a - // zero for sure. - d = 0.0; - return true; - } - // next line is safe because -real_exponent + 1 < 0 - mantissa >>= -real_exponent + 1; - // Thankfully, we can't have both "round-to-even" and subnormals because - // "round-to-even" only occurs for powers close to 0. - mantissa += (mantissa & 1); // round up - mantissa >>= 1; - // There is a weird scenario where we don't have a subnormal but just. - // Suppose we start with 2.2250738585072013e-308, we end up - // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal - // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to - // round - // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer - // subnormal, but we can only know this after rounding. - // So we only declare a subnormal if we are smaller than the threshold. - real_exponent = (mantissa < (uint64_t(1) << 52)) ? 0 : 1; - d = to_double(mantissa, real_exponent, negative); - return true; - } - // We have to round to even. The "to even" part - // is only a problem when we are right in between two floats - // which we guard against. - // If we have lots of trailing zeros, we may fall right between two - // floating-point values. - // - // The round-to-even cases take the form of a number 2m+1 which is in - // (2^53,2^54] - // times a power of two. 
That is, it is right between a number with binary - // significand - // m and another number with binary significand m+1; and it must be the case - // that it cannot be represented by a float itself. - // - // We must have that w * 10 ^q == (2m+1) * 2^p for some power of two 2^p. - // Recall that 10^q = 5^q * 2^q. - // When q >= 0, we must have that (2m+1) is divible by 5^q, so 5^q <= 2^54. - // We have that - // 5^23 <= 2^54 and it is the last power of five to qualify, so q <= 23. - // When q<0, we have w >= (2m+1) x 5^{-q}. We must have that w<2^{64} so - // (2m+1) x 5^{-q} < 2^{64}. We have that 2m+1>2^{53}. Hence, we must have - // 2^{53} x 5^{-q} < 2^{64}. - // Hence we have 5^{-q} < 2^{11}$ or q>= -4. - // - // We require lower <= 1 and not lower == 0 because we could not prove that - // that lower == 0 is implied; but we could prove that lower <= 1 is a - // necessary and sufficient test. - if (simdjson_unlikely((lower <= 1) && (power >= -4) && (power <= 23) && - ((mantissa & 3) == 1))) { - if ((mantissa << (upperbit + 64 - 53 - 2)) == upper) { - mantissa &= ~1; // flip it so that we do not round up - } - } - - mantissa += mantissa & 1; - mantissa >>= 1; - - // Here we have mantissa < (1<<53), unless there was an overflow - if (mantissa >= (1ULL << 53)) { - ////////// - // This will happen when parsing values such as 7.2057594037927933e+16 - //////// - mantissa = (1ULL << 52); - real_exponent++; - } - mantissa &= ~(1ULL << 52); - // we have to check that real_exponent is in range, otherwise we bail out - if (simdjson_unlikely(real_exponent > 2046)) { - // We have an infinite value!!! We could actually throw an error here if - // we could. - return false; - } - d = to_double(mantissa, real_exponent, negative); - return true; -} - -// We call a fallback floating-point parser that might be slow. Note -// it will accept JSON numbers, but the JSON spec. 
is more restrictive so -// before you call parse_float_fallback, you need to have validated the input -// string with the JSON grammar. -// It will return an error (false) if the parsed number is infinite. -// The string parsing itself always succeeds. We know that there is at least -// one digit. -static bool parse_float_fallback(const uint8_t *ptr, double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). - return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} -static bool parse_float_fallback(const uint8_t *ptr, - const uint8_t *end_ptr, - double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr), - reinterpret_cast(end_ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). 
- return !(*outDouble > (std::numeric_limits::max)() || - *outDouble < std::numeric_limits::lowest()); -} - -// check quickly whether the next 8 chars are made of digits -// at a glance, it looks better than Mula's -// http://0x80.pl/articles/swar-digits-validate.html -simdjson_really_inline bool is_made_of_eight_digits_fast(const uint8_t *chars) { - uint64_t val; - // this can read up to 7 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(7 <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be bigger than 7"); - std::memcpy(&val, chars, 8); - // a branchy method might be faster: - // return (( val & 0xF0F0F0F0F0F0F0F0 ) == 0x3030303030303030) - // && (( (val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0 ) == - // 0x3030303030303030); - return (((val & 0xF0F0F0F0F0F0F0F0) | - (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) == - 0x3333333333333333); -} - -template -error_code slow_float_parsing(simdjson_unused const uint8_t *src, W writer) { - double d; - if (parse_float_fallback(src, &d)) { - writer.append_double(d); - return SUCCESS; - } - return INVALID_NUMBER(src); -} - -template -SIMDJSON_NO_SANITIZE_UNDEFINED // We deliberately allow overflow here and check - // later - simdjson_really_inline bool - parse_digit(const uint8_t c, I &i) { - const uint8_t digit = static_cast(c - '0'); - if (digit > 9) { - return false; - } - // PERF NOTE: multiplication by 10 is cheaper than arbitrary integer - // multiplication - i = 10 * i + digit; // might overflow, we will handle the overflow later - return true; -} - -simdjson_really_inline error_code -parse_decimal(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - uint64_t &i, - int64_t &exponent) { - // we continue with the fiction that we have an integer. If the - // floating point number is representable as x * 10^z for some integer - // z that fits in 53 bits, then we will be able to convert back the - // the integer into a float in a lossless manner. 
- const uint8_t *const first_after_period = p; - -#ifdef SIMDJSON_SWAR_NUMBER_PARSING -#if SIMDJSON_SWAR_NUMBER_PARSING - // this helps if we have lots of decimals! - // this turns out to be frequent enough. - if (is_made_of_eight_digits_fast(p)) { - i = i * 100000000 + parse_eight_digits_unrolled(p); - p += 8; - } -#endif // SIMDJSON_SWAR_NUMBER_PARSING -#endif // #ifdef SIMDJSON_SWAR_NUMBER_PARSING - // Unrolling the first digit makes a small difference on some - // implementations (e.g. westmere) - if (parse_digit(*p, i)) { - ++p; - } - while (parse_digit(*p, i)) { - p++; - } - exponent = first_after_period - p; - // Decimal without digits (123.) is illegal - if (exponent == 0) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -simdjson_really_inline error_code -parse_exponent(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - int64_t &exponent) { - // Exp Sign: -123.456e[-]78 - bool neg_exp = ('-' == *p); - if (neg_exp || '+' == *p) { - p++; - } // Skip + as well - - // Exponent: -123.456e-[78] - auto start_exp = p; - int64_t exp_number = 0; - while (parse_digit(*p, exp_number)) { - ++p; - } - // It is possible for parse_digit to overflow. - // In particular, it could overflow to INT64_MIN, and we cannot do - - // INT64_MIN. - // Thus we *must* check for possible overflow before we negate exp_number. - - // Performance notes: it may seem like combining the two "simdjson_unlikely - // checks" below into - // a single simdjson_unlikely path would be faster. The reasoning is sound, - // but the compiler may - // not oblige and may, in fact, generate two distinct paths in any case. It - // might be - // possible to do uint64_t(p - start_exp - 1) >= 18 but it could end up - // trading off - // instructions for a simdjson_likely branch, an unconclusive gain. - - // If there were no digits, it's an error. 
- if (simdjson_unlikely(p == start_exp)) { - return INVALID_NUMBER(src); - } - // We have a valid positive exponent in exp_number at this point, except - // that - // it may have overflowed. - - // If there were more than 18 digits, we may have overflowed the integer. We - // have to do - // something!!!! - if (simdjson_unlikely(p > start_exp + 18)) { - // Skip leading zeroes: 1e000000000000000000001 is technically valid and - // doesn't overflow - while (*start_exp == '0') { - start_exp++; - } - // 19 digits could overflow int64_t and is kind of absurd anyway. We - // don't - // support exponents smaller than -999,999,999,999,999,999 and bigger - // than 999,999,999,999,999,999. - // We can truncate. - // Note that 999999999999999999 is assuredly too large. The maximal - // ieee64 value before - // infinity is ~1.8e308. The smallest subnormal is ~5e-324. So, - // actually, we could - // truncate at 324. - // Note that there is no reason to fail per se at this point in time. - // E.g., 0e999999999999999999999 is a fine number. - if (p > start_exp + 18) { - exp_number = 999999999999999999; - } - } - // At this point, we know that exp_number is a sane, positive, signed - // integer. - // It is <= 999,999,999,999,999,999. As long as 'exponent' is in - // [-8223372036854775808, 8223372036854775808], we won't overflow. Because - // 'exponent' - // is bounded in magnitude by the size of the JSON input, we are fine in - // this universe. - // To sum it up: the next line should never overflow. - exponent += (neg_exp ? -exp_number : exp_number); - return SUCCESS; -} - -simdjson_really_inline size_t significant_digits(const uint8_t *start_digits, - size_t digit_count) { - // It is possible that the integer had an overflow. - // We have to handle the case where we have 0.0000somenumber. - const uint8_t *start = start_digits; - while ((*start == '0') || (*start == '.')) { - ++start; - } - // we over-decrement by one when there is a '.' 
- return digit_count - size_t(start - start_digits); -} - -template -simdjson_really_inline error_code write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer) { - // If we frequently had to deal with long strings of digits, - // we could extend our code by using a 128-bit integer instead - // of a 64-bit integer. However, this is uncommon in practice. - // - // 9999999999999999999 < 2**64 so we can accommodate 19 digits. - // If we have a decimal separator, then digit_count - 1 is the number of - // digits, but we - // may not have a decimal separator! - if (simdjson_unlikely(digit_count > 19 && - significant_digits(start_digits, digit_count) > 19)) { - // Ok, chances are good that we had an overflow! - // this is almost never going to get called!!! - // we start anew, going slowly!!! - // This will happen in the following examples: - // 10000000000000000000000000000000000000000000e+308 - // 3.1415926535897932384626433832795028841971693993751 - // - // NOTE: This makes a *copy* of the writer and passes it to - // slow_float_parsing. This happens - // because slow_float_parsing is a non-inlined function. If we passed - // our writer reference to - // it, it would force it to be stored in memory, preventing the compiler - // from picking it apart - // and putting into registers. i.e. if we pass it as reference, it gets - // slow. - // This is what forces the skip_double, as well. - error_code error = slow_float_parsing(src, writer); - writer.skip_double(); - return error; - } - // NOTE: it's weird that the simdjson_unlikely() only wraps half the if, but - // it seems to get slower any other - // way we've tried: - // https://github.com/simdjson/simdjson/pull/990#discussion_r448497331 - // To future reader: we'd love if someone found a better way, or at least - // could explain this result! 
- if (simdjson_unlikely(exponent < simdjson::internal::smallest_power) || - (exponent > simdjson::internal::largest_power)) { - // - // Important: smallest_power is such that it leads to a zero value. - // Observe that 18446744073709551615e-343 == 0, i.e. (2**64 - 1) e -343 - // is zero - // so something x 10^-343 goes to zero, but not so with something x - // 10^-342. - static_assert(simdjson::internal::smallest_power <= -342, - "smallest_power is not small enough"); - // - if ((exponent < simdjson::internal::smallest_power) || (i == 0)) { - WRITE_DOUBLE(0, src, writer); - return SUCCESS; - } else { // (exponent > largest_power) and (i != 0) - // We have, for sure, an infinite value and simdjson refuses to - // parse infinite values. - return INVALID_NUMBER(src); - } - } - double d; - if (!compute_float_64(exponent, i, negative, d)) { - // we are almost never going to get here. - if (!parse_float_fallback(src, &d)) { - return INVALID_NUMBER(src); - } - } - WRITE_DOUBLE(d, src, writer); - return SUCCESS; -} - -// for performance analysis, it is sometimes useful to skip parsing -#ifdef SIMDJSON_SKIPNUMBERPARSING - -template -simdjson_really_inline error_code parse_number(const uint8_t *const, - W &writer) { - writer.append_s64(0); // always write zero - return SUCCESS; // always succeeds -} - -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline 
simdjson_result -parse_double_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - return ondemand::number_type::signed_integer; -} -#else - -// parse the number at src -// define JSON_TEST_NUMBERS for unit testing -// -// It is assumed that the number is followed by a structural ({,},],[) character -// or a white space character. If that is not the case (e.g., when the JSON -// document is made of a single number), then it is necessary to copy the -// content and append a space before calling this function. -// -// Our objective is accurate parsing (ULP of 0) at high speed. -template -simdjson_really_inline error_code parse_number(const uint8_t *const src, - W &writer) { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - if (digit_count == 0 || ('0' == *start_digits && digit_count > 1)) { - return INVALID_NUMBER(src); - } - - // - // Handle floats if there is a . or e (or both) - // - int64_t exponent = 0; - bool is_float = false; - if ('.' 
== *p) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_decimal(src, p, i, exponent)); - digit_count = - int(p - start_digits); // used later to guard against overflows - } - if (('e' == *p) || ('E' == *p)) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_exponent(src, p, exponent)); - } - if (is_float) { - const bool dirty_end = - jsoncharutils::is_not_structural_or_whitespace(*p); - SIMDJSON_TRY(write_float( - src, negative, i, start_digits, digit_count, exponent, writer)); - if (dirty_end) { - return INVALID_NUMBER(src); - } - return SUCCESS; - } - - // The longest negative 64-bit number is 19 digits. - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - size_t longest_digit_count = negative ? 19 : 20; - if (digit_count > longest_digit_count) { - return INVALID_NUMBER(src); - } - if (digit_count == longest_digit_count) { - if (negative) { - // Anything negative above INT64_MAX+1 is invalid - if (i > uint64_t(INT64_MAX) + 1) { - return INVALID_NUMBER(src); - } - WRITE_INTEGER(~i + 1, src, writer); - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less - // than INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit - // "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), - // the result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the - // user could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. 
- // - } else if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INVALID_NUMBER(src); - } - } - - // Write unsigned if it doesn't fit in a signed integer. - if (i > uint64_t(INT64_MAX)) { - WRITE_UNSIGNED(i, src, writer); - } else { - WRITE_INTEGER(negative ? (~i + 1) : i, src, writer); - } - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -// Inlineable functions -namespace { - -// This table can be used to characterize the final character of an integer -// string. For JSON structural character and allowable white space characters, -// we return SUCCESS. For 'e', '.' and 'E', we return INCORRECT_TYPE. Otherwise -// we return NUMBER_ERROR. -// Optimization note: we could easily reduce the size of the table by half (to -// 128) -// at the cost of an extra branch. -// Optimization note: we want the values to use at most 8 bits (not, e.g., 32 -// bits): -static_assert(error_code(uint8_t(NUMBER_ERROR)) == NUMBER_ERROR, - "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(SUCCESS)) == SUCCESS, "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(INCORRECT_TYPE)) == INCORRECT_TYPE, - "bad NUMBER_ERROR cast"); - -const uint8_t integer_string_finisher[256] = { - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, INCORRECT_TYPE, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, SUCCESS, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR}; - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. 
- if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - - -// Parse any number from 0 to 18,446,744,073,709,551,615 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. 
- size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - const uint8_t *p = src + 1; - // - // Parse the integer part. 
- // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. 
- // - // Note: we use src[1] and not src[0] because src[0] is the quote - // character in this - // instance. - if (src[1] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. 
- // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - // - // Check for minus sign - // - if (src == src_end) { - return NUMBER_ERROR; - } - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - const uint8_t *p = src + negative + 1; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... 
- // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return (*src == '-'); -} - -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if (jsoncharutils::is_structural_or_whitespace(*p)) { - return true; - } - return false; -} - -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if 
(jsoncharutils::is_structural_or_whitespace(*p)) { - int digit_count = int(p - src); - if (digit_count >= 19) { - const uint8_t *smaller_big_integer = - reinterpret_cast("9223372036854775808"); - if ((digit_count >= 20) || - (memcmp(src, smaller_big_integer, 19) >= 0)) { - return ondemand::number_type::unsigned_integer; - } - } - return ondemand::number_type::signed_integer; - } - return ondemand::number_type::floating_point_number; -} - -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src, const uint8_t *const src_end) noexcept { - if (src == src_end) { - return NUMBER_ERROR; - } - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - if (p == src_end) { - return NUMBER_ERROR; - } - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely((p != src_end) && (*p == '.'))) { - p++; - const uint8_t *start_decimal_digits = p; - if ((p == src_end) || !parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if ((p != src_end) && (*p == 'e' || *p == 'E')) { - p++; - if (p == src_end) { - return NUMBER_ERROR; - } - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while ((p != src_end) && parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if ((p != src_end) && jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, src_end, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - src += negative + 1; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. 
- // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. - overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 
0 - exp : exp; - } - - if (*p != '"') { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} -} // namespace {} -#endif // SIMDJSON_SKIPNUMBERPARSING - -} // namespace numberparsing -} // unnamed namespace -} // namespace ppc64 -} // namespace simdjson -/* end file include/simdjson/generic/numberparsing.h */ - -#endif // SIMDJSON_PPC64_NUMBERPARSING_H -/* end file include/simdjson/ppc64/numberparsing.h */ -/* begin file include/simdjson/ppc64/end.h */ -/* end file include/simdjson/ppc64/end.h */ - -#endif // SIMDJSON_IMPLEMENTATION_PPC64 - -#endif // SIMDJSON_PPC64_H -/* end file include/simdjson/ppc64.h */ -/* begin file include/simdjson/westmere.h */ -#ifndef SIMDJSON_WESTMERE_H -#define SIMDJSON_WESTMERE_H - - -#if SIMDJSON_IMPLEMENTATION_WESTMERE - -#if SIMDJSON_CAN_ALWAYS_RUN_WESTMERE -#define SIMDJSON_TARGET_WESTMERE -#define SIMDJSON_UNTARGET_WESTMERE -#else -#define SIMDJSON_TARGET_WESTMERE SIMDJSON_TARGET_REGION("sse4.2,pclmul") -#define SIMDJSON_UNTARGET_WESTMERE SIMDJSON_UNTARGET_REGION -#endif - -namespace simdjson { -/** - * Implementation for Westmere (Intel SSE4.2). 
- */ -namespace westmere {} // namespace westmere -} // namespace simdjson - -// -// These two need to be included outside SIMDJSON_TARGET_WESTMERE -// -/* begin file include/simdjson/westmere/implementation.h */ -#ifndef SIMDJSON_WESTMERE_IMPLEMENTATION_H -#define SIMDJSON_WESTMERE_IMPLEMENTATION_H - - -// The constructor may be executed on any host, so we take care not to use -// SIMDJSON_TARGET_WESTMERE -namespace simdjson { -namespace westmere { - -namespace { -using namespace simdjson; -using namespace simdjson::dom; -} - -class implementation final : public simdjson::implementation { - public: - simdjson_really_inline implementation() - : simdjson::implementation("westmere", - "Intel/AMD SSE4.2", - internal::instruction_set::SSE42 | - internal::instruction_set::PCLMULQDQ) {} - simdjson_warn_unused error_code create_dom_parser_implementation( - size_t capacity, - size_t max_length, - std::unique_ptr &dst) const - noexcept final; - simdjson_warn_unused error_code - minify(const uint8_t *buf, size_t len, uint8_t *dst, size_t &dst_len) const - noexcept final; - simdjson_warn_unused bool validate_utf8(const char *buf, size_t len) const - noexcept final; -}; - -} // namespace westmere -} // namespace simdjson - -#endif // SIMDJSON_WESTMERE_IMPLEMENTATION_H -/* end file include/simdjson/westmere/implementation.h */ -/* begin file include/simdjson/westmere/intrinsics.h */ -#ifndef SIMDJSON_WESTMERE_INTRINSICS_H -#define SIMDJSON_WESTMERE_INTRINSICS_H - -#ifdef SIMDJSON_VISUAL_STUDIO -// under clang within visual studio, this will include -#include // visual studio or clang -#else -#include // elsewhere -#endif // SIMDJSON_VISUAL_STUDIO - - -#ifdef SIMDJSON_CLANG_VISUAL_STUDIO -/** - * You are not supposed, normally, to include these - * headers directly. Instead you should either include intrin.h - * or x86intrin.h. 
However, when compiling with clang - * under Windows (i.e., when _MSC_VER is set), these headers - * only get included *if* the corresponding features are detected - * from macros: - */ -#include // for _mm_alignr_epi8 -#include // for _mm_clmulepi64_si128 -#endif - - -#endif // SIMDJSON_WESTMERE_INTRINSICS_H -/* end file include/simdjson/westmere/intrinsics.h */ - -// -// The rest need to be inside the region -// -/* begin file include/simdjson/westmere/begin.h */ -// redefining SIMDJSON_IMPLEMENTATION to "westmere" -// #define SIMDJSON_IMPLEMENTATION westmere -SIMDJSON_TARGET_WESTMERE -/* end file include/simdjson/westmere/begin.h */ - -// Declarations -/* begin file include/simdjson/generic/dom_parser_implementation.h */ - -namespace simdjson { -namespace westmere { - -// expectation: sizeof(open_container) = 64/8. -struct open_container { - uint32_t tape_index; // where, on the tape, does the scope ([,{) begins - uint32_t count; // how many elements in the scope -}; // struct open_container - -static_assert(sizeof(open_container) == 64 / 8, - "Open container must be 64 bits"); - -class dom_parser_implementation final - : public internal::dom_parser_implementation { - public: - /** Tape location of each open { or [ */ - std::unique_ptr open_containers{}; - /** Whether each open container is a [ or { */ - std::unique_ptr is_array{}; - /** Buffer passed to stage 1 */ - const uint8_t *buf{}; - /** Length passed to stage 1 */ - size_t len{0}; - /** Document passed to stage 2 */ - dom::document *doc{}; - - inline dom_parser_implementation() noexcept; - inline dom_parser_implementation( - dom_parser_implementation &&other) noexcept; - inline dom_parser_implementation &operator=( - dom_parser_implementation &&other) noexcept; - dom_parser_implementation(const dom_parser_implementation &) = delete; - dom_parser_implementation &operator=(const dom_parser_implementation &) = - delete; - - simdjson_warn_unused error_code parse(const uint8_t *buf, - size_t len, - 
dom::document &doc) noexcept final; - simdjson_warn_unused error_code stage1(const uint8_t *buf, - size_t len, - stage1_mode partial) noexcept final; - simdjson_warn_unused error_code stage2(dom::document &doc) noexcept final; - simdjson_warn_unused error_code - stage2_next(dom::document &doc) noexcept final; - inline simdjson_warn_unused error_code - set_capacity(size_t capacity) noexcept final; - inline simdjson_warn_unused error_code - set_max_depth(size_t max_depth) noexcept final; - - private: - simdjson_really_inline simdjson_warn_unused error_code - set_capacity_stage1(size_t capacity); -}; - -} // namespace westmere -} // namespace simdjson - -namespace simdjson { -namespace westmere { - -inline dom_parser_implementation::dom_parser_implementation() noexcept = - default; -inline dom_parser_implementation::dom_parser_implementation( - dom_parser_implementation &&other) noexcept = default; -inline dom_parser_implementation &dom_parser_implementation::operator=( - dom_parser_implementation &&other) noexcept = default; - -// Leaving these here so they can be inlined if so desired -inline simdjson_warn_unused error_code -dom_parser_implementation::set_capacity(size_t capacity) noexcept { - if (capacity > SIMDJSON_MAXSIZE_BYTES) { - return CAPACITY; - } - // Stage 1 index output - size_t max_structures = SIMDJSON_ROUNDUP_N(capacity, 64) + 2 + 7; - structural_indexes.reset(new (std::nothrow) uint32_t[max_structures]); - if (!structural_indexes) { - _capacity = 0; - return MEMALLOC; - } - structural_indexes[0] = 0; - n_structural_indexes = 0; - - _capacity = capacity; - return SUCCESS; -} - -inline simdjson_warn_unused error_code -dom_parser_implementation::set_max_depth(size_t max_depth) noexcept { - // Stage 2 stacks - open_containers.reset(new (std::nothrow) open_container[max_depth]); - is_array.reset(new (std::nothrow) bool[max_depth]); - if (!is_array || !open_containers) { - _max_depth = 0; - return MEMALLOC; - } - - _max_depth = max_depth; - return SUCCESS; 
-} - -} // namespace westmere -} // namespace simdjson -/* end file include/simdjson/generic/dom_parser_implementation.h */ -/* begin file include/simdjson/westmere/bitmanipulation.h */ -#ifndef SIMDJSON_WESTMERE_BITMANIPULATION_H -#define SIMDJSON_WESTMERE_BITMANIPULATION_H - -namespace simdjson { -namespace westmere { -namespace { - -// We sometimes call trailing_zero on inputs that are zero, -// but the algorithms do not end up using the returned value. -// Sadly, sanitizers are not smart enough to figure it out. -SIMDJSON_NO_SANITIZE_UNDEFINED -simdjson_really_inline int trailing_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long ret; - // Search the mask data from least significant bit (LSB) - // to the most significant bit (MSB) for a set bit (1). - _BitScanForward64(&ret, input_num); - return (int)ret; -#else // SIMDJSON_REGULAR_VISUAL_STUDIO - return __builtin_ctzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline uint64_t clear_lowest_bit(uint64_t input_num) { - return input_num & (input_num - 1); -} - -/* result might be undefined when input_num is zero */ -simdjson_really_inline int leading_zeroes(uint64_t input_num) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - unsigned long leading_zero = 0; - // Search the mask data from most significant bit (MSB) - // to least significant bit (LSB) for a set bit (1). 
- if (_BitScanReverse64(&leading_zero, input_num)) - return (int)(63 - leading_zero); - else - return 64; -#else - return __builtin_clzll(input_num); -#endif // SIMDJSON_REGULAR_VISUAL_STUDIO -} - -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO -simdjson_really_inline unsigned __int64 count_ones(uint64_t input_num) { - // note: we do not support legacy 32-bit Windows - return __popcnt64(input_num); // Visual Studio wants two underscores -} -#else -simdjson_really_inline long long int count_ones(uint64_t input_num) { - return _popcnt64(input_num); -} -#endif - -simdjson_really_inline bool add_overflow(uint64_t value1, - uint64_t value2, - uint64_t *result) { -#ifdef SIMDJSON_REGULAR_VISUAL_STUDIO - return _addcarry_u64( - 0, value1, value2, reinterpret_cast(result)); -#else - return __builtin_uaddll_overflow( - value1, value2, reinterpret_cast(result)); -#endif -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson - -#endif // SIMDJSON_WESTMERE_BITMANIPULATION_H -/* end file include/simdjson/westmere/bitmanipulation.h */ -/* begin file include/simdjson/westmere/bitmask.h */ -#ifndef SIMDJSON_WESTMERE_BITMASK_H -#define SIMDJSON_WESTMERE_BITMASK_H - -namespace simdjson { -namespace westmere { -namespace { - -// -// Perform a "cumulative bitwise xor," flipping bits each time a 1 is -// encountered. -// -// For example, prefix_xor(00100100) == 00011100 -// -simdjson_really_inline uint64_t prefix_xor(const uint64_t bitmask) { - // There should be no such thing with a processing supporting avx2 - // but not clmul. 
- __m128i all_ones = _mm_set1_epi8('\xFF'); - __m128i result = - _mm_clmulepi64_si128(_mm_set_epi64x(0ULL, bitmask), all_ones, 0); - return _mm_cvtsi128_si64(result); -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson - -#endif // SIMDJSON_WESTMERE_BITMASK_H -/* end file include/simdjson/westmere/bitmask.h */ -/* begin file include/simdjson/westmere/simd.h */ -#ifndef SIMDJSON_WESTMERE_SIMD_H -#define SIMDJSON_WESTMERE_SIMD_H - - -namespace simdjson { -namespace westmere { -namespace { -namespace simd { - -template -struct base { - __m128i value; - - // Zero constructor - simdjson_really_inline base() : value{__m128i()} {} - - // Conversion from SIMD register - simdjson_really_inline base(const __m128i _value) : value(_value) {} - - // Conversion to SIMD register - simdjson_really_inline operator const __m128i &() const { - return this->value; - } - simdjson_really_inline operator __m128i &() { return this->value; } - - // Bit operations - simdjson_really_inline Child operator|(const Child other) const { - return _mm_or_si128(*this, other); - } - simdjson_really_inline Child operator&(const Child other) const { - return _mm_and_si128(*this, other); - } - simdjson_really_inline Child operator^(const Child other) const { - return _mm_xor_si128(*this, other); - } - simdjson_really_inline Child bit_andnot(const Child other) const { - return _mm_andnot_si128(other, *this); - } - simdjson_really_inline Child &operator|=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast | other; - return *this_cast; - } - simdjson_really_inline Child &operator&=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast & other; - return *this_cast; - } - simdjson_really_inline Child &operator^=(const Child other) { - auto this_cast = static_cast(this); - *this_cast = *this_cast ^ other; - return *this_cast; - } -}; - -// Forward-declared so they can be used by splat and friends. 
-template -struct simd8; - -template > -struct base8 : base> { - typedef uint16_t bitmask_t; - typedef uint32_t bitmask2_t; - - simdjson_really_inline base8() : base>() {} - simdjson_really_inline base8(const __m128i _value) - : base>(_value) {} - - friend simdjson_really_inline Mask operator==(const simd8 lhs, - const simd8 rhs) { - return _mm_cmpeq_epi8(lhs, rhs); - } - - static const int SIZE = sizeof(base>::value); - - template - simdjson_really_inline simd8 prev(const simd8 prev_chunk) const { - return _mm_alignr_epi8(*this, prev_chunk, 16 - N); - } -}; - -// SIMD byte mask type (returned by things like eq and gt) -template <> -struct simd8 : base8 { - static simdjson_really_inline simd8 splat(bool _value) { - return _mm_set1_epi8(uint8_t(-(!!_value))); - } - - simdjson_really_inline simd8() : base8() {} - simdjson_really_inline simd8(const __m128i _value) - : base8(_value) {} - // Splat constructor - simdjson_really_inline simd8(bool _value) - : base8(splat(_value)) {} - - simdjson_really_inline int to_bitmask() const { - return _mm_movemask_epi8(*this); - } - simdjson_really_inline bool any() const { - return !_mm_testz_si128(*this, *this); - } - simdjson_really_inline simd8 operator~() const { - return *this ^ true; - } -}; - -template -struct base8_numeric : base8 { - static simdjson_really_inline simd8 splat(T _value) { - return _mm_set1_epi8(_value); - } - static simdjson_really_inline simd8 zero() { - return _mm_setzero_si128(); - } - static simdjson_really_inline simd8 load(const T values[16]) { - return _mm_loadu_si128(reinterpret_cast(values)); - } - // Repeat 16 values as many times as necessary (usually for lookup tables) - static simdjson_really_inline simd8 repeat_16(T v0, - T v1, - T v2, - T v3, - T v4, - T v5, - T v6, - T v7, - T v8, - T v9, - T v10, - T v11, - T v12, - T v13, - T v14, - T v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - simdjson_really_inline 
base8_numeric() : base8() {} - simdjson_really_inline base8_numeric(const __m128i _value) - : base8(_value) {} - - // Store to array - simdjson_really_inline void store(T dst[16]) const { - return _mm_storeu_si128(reinterpret_cast<__m128i *>(dst), *this); - } - - // Override to distinguish from bool version - simdjson_really_inline simd8 operator~() const { return *this ^ 0xFFu; } - - // Addition/subtraction are the same for signed and unsigned - simdjson_really_inline simd8 operator+(const simd8 other) const { - return _mm_add_epi8(*this, other); - } - simdjson_really_inline simd8 operator-(const simd8 other) const { - return _mm_sub_epi8(*this, other); - } - simdjson_really_inline simd8 &operator+=(const simd8 other) { - *this = *this + other; - return *static_cast *>(this); - } - simdjson_really_inline simd8 &operator-=(const simd8 other) { - *this = *this - other; - return *static_cast *>(this); - } - - // Perform a lookup assuming the value is between 0 and 16 (undefined - // behavior for out of range values) - template - simdjson_really_inline simd8 lookup_16(simd8 lookup_table) const { - return _mm_shuffle_epi8(lookup_table, *this); - } - - // Copies to 'output" all bytes corresponding to a 0 in the mask - // (interpreted as a bitset). - // Passing a 0 value for mask would be equivalent to writing out every byte - // to output. - // Only the first 16 - count_ones(mask) bytes of the result are significant - // but 16 bytes - // get written. - // Design consideration: it seems like a function with the - // signature simd8 compress(uint32_t mask) would be - // sensible, but the AVX ISA makes this kind of approach difficult. 
- template - simdjson_really_inline void compress(uint16_t mask, L *output) const { - using internal::thintable_epi8; - using internal::BitsSetTable256mul2; - using internal::pshufb_combine_table; - // this particular implementation was inspired by work done by - // @animetosho - // we do it in two steps, first 8 bytes and then second 8 bytes - uint8_t mask1 = uint8_t(mask); // least significant 8 bits - uint8_t mask2 = uint8_t(mask >> 8); // most significant 8 bits - // next line just loads the 64-bit values thintable_epi8[mask1] and - // thintable_epi8[mask2] into a 128-bit register, using only - // two instructions on most compilers. - __m128i shufmask = - _mm_set_epi64x(thintable_epi8[mask2], thintable_epi8[mask1]); - // we increment by 0x08 the second half of the mask - shufmask = - _mm_add_epi8(shufmask, _mm_set_epi32(0x08080808, 0x08080808, 0, 0)); - // this is the version "nearly pruned" - __m128i pruned = _mm_shuffle_epi8(*this, shufmask); - // we still need to put the two halves together. - // we compute the popcount of the first half: - int pop1 = BitsSetTable256mul2[mask1]; - // then load the corresponding mask, what it does is to write - // only the first pop1 bytes from the first 8 bytes, and then - // it fills in with the bytes from the second 8 bytes + some filling - // at the end. 
- __m128i compactmask = _mm_loadu_si128( - reinterpret_cast(pshufb_combine_table + pop1 * 8)); - __m128i answer = _mm_shuffle_epi8(pruned, compactmask); - _mm_storeu_si128(reinterpret_cast<__m128i *>(output), answer); - } - - template - simdjson_really_inline simd8 lookup_16(L replace0, - L replace1, - L replace2, - L replace3, - L replace4, - L replace5, - L replace6, - L replace7, - L replace8, - L replace9, - L replace10, - L replace11, - L replace12, - L replace13, - L replace14, - L replace15) const { - return lookup_16(simd8::repeat_16(replace0, - replace1, - replace2, - replace3, - replace4, - replace5, - replace6, - replace7, - replace8, - replace9, - replace10, - replace11, - replace12, - replace13, - replace14, - replace15)); - } -}; - -// Signed bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m128i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(int8_t _value) : simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const int8_t *values) : simd8(load(values)) {} - // Member-by-member initialization - simdjson_really_inline simd8(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) - : simd8(_mm_setr_epi8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15)) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(int8_t v0, - int8_t v1, - int8_t v2, - int8_t v3, - int8_t v4, - int8_t v5, - int8_t v6, - int8_t v7, - int8_t v8, - int8_t v9, - int8_t v10, - int8_t v11, - int8_t v12, - int8_t v13, - int8_t v14, - int8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - 
v14, - v15); - } - - // Order-sensitive comparisons - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return _mm_max_epi8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return _mm_min_epi8(*this, other); - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return _mm_cmpgt_epi8(*this, other); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return _mm_cmpgt_epi8(other, *this); - } -}; - -// Unsigned bytes -template <> -struct simd8 : base8_numeric { - simdjson_really_inline simd8() : base8_numeric() {} - simdjson_really_inline simd8(const __m128i _value) - : base8_numeric(_value) {} - // Splat constructor - simdjson_really_inline simd8(uint8_t _value) : simd8(splat(_value)) {} - // Array constructor - simdjson_really_inline simd8(const uint8_t *values) : simd8(load(values)) {} - // Member-by-member initialization - simdjson_really_inline simd8(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) - : simd8(_mm_setr_epi8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15)) {} - // Repeat 16 values as many times as necessary (usually for lookup tables) - simdjson_really_inline static simd8 repeat_16(uint8_t v0, - uint8_t v1, - uint8_t v2, - uint8_t v3, - uint8_t v4, - uint8_t v5, - uint8_t v6, - uint8_t v7, - uint8_t v8, - uint8_t v9, - uint8_t v10, - uint8_t v11, - uint8_t v12, - uint8_t v13, - uint8_t v14, - uint8_t v15) { - return simd8(v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7, - v8, - v9, - v10, - v11, - v12, - v13, - v14, - v15); - } - - // Saturated math - simdjson_really_inline simd8 saturating_add( - const simd8 other) const { - return _mm_adds_epu8(*this, other); - } - simdjson_really_inline simd8 saturating_sub( - const 
simd8 other) const { - return _mm_subs_epu8(*this, other); - } - - // Order-specific operations - simdjson_really_inline simd8 max_val( - const simd8 other) const { - return _mm_max_epu8(*this, other); - } - simdjson_really_inline simd8 min_val( - const simd8 other) const { - return _mm_min_epu8(*this, other); - } - // Same as >, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 gt_bits( - const simd8 other) const { - return this->saturating_sub(other); - } - // Same as <, but only guarantees true is nonzero (< guarantees true = -1) - simdjson_really_inline simd8 lt_bits( - const simd8 other) const { - return other.saturating_sub(*this); - } - simdjson_really_inline simd8 operator<=( - const simd8 other) const { - return other.max_val(*this) == other; - } - simdjson_really_inline simd8 operator>=( - const simd8 other) const { - return other.min_val(*this) == other; - } - simdjson_really_inline simd8 operator>( - const simd8 other) const { - return this->gt_bits(other).any_bits_set(); - } - simdjson_really_inline simd8 operator<( - const simd8 other) const { - return this->gt_bits(other).any_bits_set(); - } - - // Bit-specific operations - simdjson_really_inline simd8 bits_not_set() const { - return *this == uint8_t(0); - } - simdjson_really_inline simd8 bits_not_set(simd8 bits) const { - return (*this & bits).bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set() const { - return ~this->bits_not_set(); - } - simdjson_really_inline simd8 any_bits_set(simd8 bits) const { - return ~this->bits_not_set(bits); - } - simdjson_really_inline bool is_ascii() const { - return _mm_movemask_epi8(*this) == 0; - } - simdjson_really_inline bool bits_not_set_anywhere() const { - return _mm_testz_si128(*this, *this); - } - simdjson_really_inline bool any_bits_set_anywhere() const { - return !bits_not_set_anywhere(); - } - simdjson_really_inline bool bits_not_set_anywhere( - simd8 bits) const { - return _mm_testz_si128(*this, 
bits); - } - simdjson_really_inline bool any_bits_set_anywhere( - simd8 bits) const { - return !bits_not_set_anywhere(bits); - } - template - simdjson_really_inline simd8 shr() const { - return simd8(_mm_srli_epi16(*this, N)) & uint8_t(0xFFu >> N); - } - template - simdjson_really_inline simd8 shl() const { - return simd8(_mm_slli_epi16(*this, N)) & uint8_t(0xFFu << N); - } - // Get one of the bits and make a bitmask out of it. - // e.g. value.get_bit<7>() gets the high bit - template - simdjson_really_inline int get_bit() const { - return _mm_movemask_epi8(_mm_slli_epi16(*this, 7 - N)); - } -}; - -template -struct simd8x64 { - static constexpr int NUM_CHUNKS = 64 / sizeof(simd8); - static_assert( - NUM_CHUNKS == 4, - "Westmere kernel should use four registers per 64-byte block."); - const simd8 chunks[NUM_CHUNKS]; - - simd8x64(const simd8x64 &o) = delete; // no copy allowed - simd8x64 &operator=(const simd8 &other) = - delete; // no assignment allowed - simd8x64() = delete; // no default constructor allowed - - simdjson_really_inline simd8x64(const simd8 chunk0, - const simd8 chunk1, - const simd8 chunk2, - const simd8 chunk3) - : chunks{chunk0, chunk1, chunk2, chunk3} {} - simdjson_really_inline simd8x64(const T ptr[64]) - : chunks{simd8::load(ptr), - simd8::load(ptr + 16), - simd8::load(ptr + 32), - simd8::load(ptr + 48)} {} - - simdjson_really_inline void store(T ptr[64]) const { - this->chunks[0].store(ptr + sizeof(simd8) * 0); - this->chunks[1].store(ptr + sizeof(simd8) * 1); - this->chunks[2].store(ptr + sizeof(simd8) * 2); - this->chunks[3].store(ptr + sizeof(simd8) * 3); - } - - simdjson_really_inline simd8 reduce_or() const { - return (this->chunks[0] | this->chunks[1]) | - (this->chunks[2] | this->chunks[3]); - } - - simdjson_really_inline uint64_t compress(uint64_t mask, T *output) const { - this->chunks[0].compress(uint16_t(mask), output); - this->chunks[1].compress(uint16_t(mask >> 16), - output + 16 - count_ones(mask & 0xFFFF)); - 
this->chunks[2].compress(uint16_t(mask >> 32), - output + 32 - count_ones(mask & 0xFFFFFFFF)); - this->chunks[3].compress( - uint16_t(mask >> 48), - output + 48 - count_ones(mask & 0xFFFFFFFFFFFF)); - return 64 - count_ones(mask); - } - - simdjson_really_inline uint64_t to_bitmask() const { - uint64_t r0 = uint32_t(this->chunks[0].to_bitmask()); - uint64_t r1 = this->chunks[1].to_bitmask(); - uint64_t r2 = this->chunks[2].to_bitmask(); - uint64_t r3 = this->chunks[3].to_bitmask(); - return r0 | (r1 << 16) | (r2 << 32) | (r3 << 48); - } - - simdjson_really_inline uint64_t eq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] == mask, - this->chunks[1] == mask, - this->chunks[2] == mask, - this->chunks[3] == mask) - .to_bitmask(); - } - - simdjson_really_inline uint64_t eq(const simd8x64 &other) const { - return simd8x64(this->chunks[0] == other.chunks[0], - this->chunks[1] == other.chunks[1], - this->chunks[2] == other.chunks[2], - this->chunks[3] == other.chunks[3]) - .to_bitmask(); - } - - simdjson_really_inline uint64_t lteq(const T m) const { - const simd8 mask = simd8::splat(m); - return simd8x64(this->chunks[0] <= mask, - this->chunks[1] <= mask, - this->chunks[2] <= mask, - this->chunks[3] <= mask) - .to_bitmask(); - } -}; // struct simd8x64 - -} // namespace simd -} // unnamed namespace -} // namespace westmere -} // namespace simdjson - -#endif // SIMDJSON_WESTMERE_SIMD_INPUT_H -/* end file include/simdjson/westmere/simd.h */ -/* begin file include/simdjson/generic/jsoncharutils.h */ - -namespace simdjson { -namespace westmere { -namespace { -namespace jsoncharutils { - -// return non-zero if not a structural or whitespace char -// zero otherwise -simdjson_really_inline uint32_t is_not_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace_negated[c]; -} - -simdjson_really_inline uint32_t is_structural_or_whitespace(uint8_t c) { - return internal::structural_or_whitespace[c]; -} - -// 
returns a value with the high 16 bits set if not valid -// otherwise returns the conversion of the 4 hex digits at src into the bottom -// 16 bits of the 32-bit return register -// -// see -// https://lemire.me/blog/2019/04/17/parsing-short-hexadecimal-strings-efficiently/ -static inline uint32_t hex_to_u32_nocheck( - const uint8_t *src) { // strictly speaking, static inline is a C-ism - uint32_t v1 = internal::digit_to_val32[630 + src[0]]; - uint32_t v2 = internal::digit_to_val32[420 + src[1]]; - uint32_t v3 = internal::digit_to_val32[210 + src[2]]; - uint32_t v4 = internal::digit_to_val32[0 + src[3]]; - return v1 | v2 | v3 | v4; -} - -// given a code point cp, writes to c -// the utf-8 code, outputting the length in -// bytes, if the length is zero, the code point -// is invalid -// -// This can possibly be made faster using pdep -// and clz and table lookups, but JSON documents -// have few escaped code points, and the following -// function looks cheap. -// -// Note: we assume that surrogates are treated separately -// -simdjson_really_inline size_t codepoint_to_utf8(uint32_t cp, uint8_t *c) { - if (cp <= 0x7F) { - c[0] = uint8_t(cp); - return 1; // ascii - } - if (cp <= 0x7FF) { - c[0] = uint8_t((cp >> 6) + 192); - c[1] = uint8_t((cp & 63) + 128); - return 2; // universal plane - // Surrogates are treated elsewhere... - //} //else if (0xd800 <= cp && cp <= 0xdfff) { - // return 0; // surrogates // could put assert here - } else if (cp <= 0xFFFF) { - c[0] = uint8_t((cp >> 12) + 224); - c[1] = uint8_t(((cp >> 6) & 63) + 128); - c[2] = uint8_t((cp & 63) + 128); - return 3; - } else if (cp <= - 0x10FFFF) { // if you know you have a valid code point, this - // is not needed - c[0] = uint8_t((cp >> 18) + 240); - c[1] = uint8_t(((cp >> 12) & 63) + 128); - c[2] = uint8_t(((cp >> 6) & 63) + 128); - c[3] = uint8_t((cp & 63) + 128); - return 4; - } - // will return 0 when the code point was too large. 
- return 0; // bad r -} - -#ifdef SIMDJSON_IS_32BITS // _umul128 for x86, arm -// this is a slow emulation routine for 32-bit -// -static simdjson_really_inline uint64_t __emulu(uint32_t x, uint32_t y) { - return x * (uint64_t)y; -} -static simdjson_really_inline uint64_t _umul128(uint64_t ab, - uint64_t cd, - uint64_t *hi) { - uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd); - uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd); - uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32)); - uint64_t adbc_carry = !!(adbc < ad); - uint64_t lo = bd + (adbc << 32); - *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) + - (adbc_carry << 32) + !!(lo < bd); - return lo; -} -#endif - -using internal::value128; - -simdjson_really_inline value128 full_multiplication(uint64_t value1, - uint64_t value2) { - value128 answer; -#if defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) -#ifdef _M_ARM64 - // ARM64 has native support for 64-bit multiplications, no need to emultate - answer.high = __umulh(value1, value2); - answer.low = value1 * value2; -#else - answer.low = _umul128( - value1, value2, &answer.high); // _umul128 not available on ARM64 -#endif // _M_ARM64 -#else // defined(SIMDJSON_REGULAR_VISUAL_STUDIO) || defined(SIMDJSON_IS_32BITS) - __uint128_t r = (static_cast<__uint128_t>(value1)) * value2; - answer.low = uint64_t(r); - answer.high = uint64_t(r >> 64); -#endif - return answer; -} - -} // namespace jsoncharutils -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file include/simdjson/generic/jsoncharutils.h */ -/* begin file include/simdjson/generic/atomparsing.h */ -namespace simdjson { -namespace westmere { -namespace { -/// @private -namespace atomparsing { - -// The string_to_uint32 is exclusively used to map literal strings to 32-bit -// values. 
-// We use memcpy instead of a pointer cast to avoid undefined behaviors since we -// cannot -// be certain that the character pointer will be properly aligned. -// You might think that using memcpy makes this function expensive, but you'd be -// wrong. -// All decent optimizing compilers (GCC, clang, Visual Studio) will compile -// string_to_uint32("false"); -// to the compile-time constant 1936482662. -simdjson_really_inline uint32_t string_to_uint32(const char *str) { - uint32_t val; - std::memcpy(&val, str, sizeof(uint32_t)); - return val; -} - - -// Again in str4ncmp we use a memcpy to avoid undefined behavior. The memcpy may -// appear expensive. -// Yet all decent optimizing compilers will compile memcpy to a single -// instruction, just about. -simdjson_warn_unused simdjson_really_inline uint32_t -str4ncmp(const uint8_t *src, const char *atom) { - uint32_t - srcval; // we want to avoid unaligned 32-bit loads (undefined in C/C++) - static_assert(sizeof(uint32_t) <= SIMDJSON_PADDING, - "SIMDJSON_PADDING must be larger than 4 bytes"); - std::memcpy(&srcval, src, sizeof(uint32_t)); - return srcval ^ string_to_uint32(atom); -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src) { - return (str4ncmp(src, "true") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_true_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_true_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "true"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src) { - return (str4ncmp(src + 1, "alse") | - jsoncharutils::is_not_structural_or_whitespace(src[5])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_false_atom( - const uint8_t *src, size_t len) { - if (len > 5) { - return is_valid_false_atom(src); - } else if (len == 5) { - return 
!str4ncmp(src + 1, "alse"); - } else { - return false; - } -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src) { - return (str4ncmp(src, "null") | - jsoncharutils::is_not_structural_or_whitespace(src[4])) == 0; -} - -simdjson_warn_unused simdjson_really_inline bool is_valid_null_atom( - const uint8_t *src, size_t len) { - if (len > 4) { - return is_valid_null_atom(src); - } else if (len == 4) { - return !str4ncmp(src, "null"); - } else { - return false; - } -} - -} // namespace atomparsing -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file include/simdjson/generic/atomparsing.h */ -/* begin file include/simdjson/westmere/stringparsing.h */ -#ifndef SIMDJSON_WESTMERE_STRINGPARSING_H -#define SIMDJSON_WESTMERE_STRINGPARSING_H - -namespace simdjson { -namespace westmere { -namespace { - -using namespace simd; - -// Holds backslashes and quotes locations. -struct backslash_and_quote { - public: - static constexpr uint32_t BYTES_PROCESSED = 32; - simdjson_really_inline static backslash_and_quote copy_and_find( - const uint8_t *src, uint8_t *dst); - - simdjson_really_inline bool has_quote_first() { - return ((bs_bits - 1) & quote_bits) != 0; - } - simdjson_really_inline bool has_backslash() { return bs_bits != 0; } - simdjson_really_inline int quote_index() { - return trailing_zeroes(quote_bits); - } - simdjson_really_inline int backslash_index() { - return trailing_zeroes(bs_bits); - } - - uint32_t bs_bits; - uint32_t quote_bits; -}; // struct backslash_and_quote - -simdjson_really_inline backslash_and_quote -backslash_and_quote::copy_and_find(const uint8_t *src, uint8_t *dst) { - // this can read up to 31 bytes beyond the buffer size, but we require - // SIMDJSON_PADDING of padding - static_assert(SIMDJSON_PADDING >= (BYTES_PROCESSED - 1), - "backslash and quote finder must process fewer than " - "SIMDJSON_PADDING bytes"); - simd8 v0(src); - simd8 v1(src + 16); - v0.store(dst); - 
v1.store(dst + 16); - uint64_t bs_and_quote = - simd8x64(v0 == '\\', v1 == '\\', v0 == '"', v1 == '"') - .to_bitmask(); - return { - uint32_t(bs_and_quote), // bs_bits - uint32_t(bs_and_quote >> 32) // quote_bits - }; -} - -} // unnamed namespace -} // namespace westmere -} // namespace simdjson - -/* begin file include/simdjson/generic/stringparsing.h */ -// This file contains the common code every implementation uses -// It is intended to be included multiple times and compiled multiple times - -namespace simdjson { -namespace westmere { -namespace { -/// @private -namespace stringparsing { - -// begin copypasta -// These chars yield themselves: " \ / -// b -> backspace, f -> formfeed, n -> newline, r -> cr, t -> horizontal tab -// u not handled in this table as it's complex -static const uint8_t escape_map[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x4. - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x5c, 0, 0, 0, // 0x5. - 0, 0, 0x08, 0, 0, 0, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0x0a, 0, // 0x6. - 0, 0, 0x0d, 0, 0x09, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x7. 
- - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -// handle a unicode codepoint -// write appropriate values into dest -// src will advance 6 bytes or 12 bytes -// dest will advance a variable amount (return via pointer) -// return true if the unicode codepoint was valid -// We work in little-endian then swap at write time -simdjson_warn_unused simdjson_really_inline bool handle_unicode_codepoint( - const uint8_t **src_ptr, uint8_t **dst_ptr) { - // jsoncharutils::hex_to_u32_nocheck fills high 16 bits of the return value - // with 1s if the - // conversion isn't valid; we defer the check for this to inside the - // multilingual plane check - uint32_t code_point = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - *src_ptr += 6; - // check for low surrogate for characters outside the Basic - // Multilingual Plane. - if (code_point >= 0xd800 && code_point < 0xdc00) { - if (((*src_ptr)[0] != '\\') || (*src_ptr)[1] != 'u') { - return false; - } - uint32_t code_point_2 = jsoncharutils::hex_to_u32_nocheck(*src_ptr + 2); - - // if the first code point is invalid we will get here, as we will go - // past - // the check for being outside the Basic Multilingual plane. If we don't - // find a \u immediately afterwards we fail out anyhow, but if we do, - // this check catches both the case of the first code point being - // invalid - // or the second code point being invalid. 
- if ((code_point | code_point_2) >> 16) { - return false; - } - - code_point = - (((code_point - 0xd800) << 10) | (code_point_2 - 0xdc00)) + 0x10000; - *src_ptr += 6; - } - size_t offset = jsoncharutils::codepoint_to_utf8(code_point, *dst_ptr); - *dst_ptr += offset; - return offset > 0; -} - -/** - * Unescape a string from src to dst, stopping at a final unescaped quote. E.g., - * if src points at 'joe"', then - * dst needs to have four free bytes. - */ -simdjson_warn_unused simdjson_really_inline uint8_t *parse_string( - const uint8_t *src, uint8_t *dst) { - while (1) { - // Copy the next n bytes, and find the backslash and quote in them. - auto bs_quote = backslash_and_quote::copy_and_find(src, dst); - // If the next thing is the end quote, copy and return - if (bs_quote.has_quote_first()) { - // we encountered quotes first. Move dst to point to quotes and exit - return dst + bs_quote.quote_index(); - } - if (bs_quote.has_backslash()) { - /* find out where the backspace is */ - auto bs_dist = bs_quote.backslash_index(); - uint8_t escape_char = src[bs_dist + 1]; - /* we encountered backslash first. Handle backslash */ - if (escape_char == 'u') { - /* move src/dst up to the start; they will be further adjusted - within the unicode codepoint handling code. */ - src += bs_dist; - dst += bs_dist; - if (!handle_unicode_codepoint(&src, &dst)) { - return nullptr; - } - } else { - /* simple 1:1 conversion. Will eat bs_dist+2 characters in input - * and - * write bs_dist+1 characters to output - * note this may reach beyond the part of the buffer we've - * actually - * seen. I think this is ok */ - uint8_t escape_result = escape_map[escape_char]; - if (escape_result == 0u) { - return nullptr; /* bogus escape value is an error */ - } - dst[bs_dist] = escape_result; - src += bs_dist + 2; - dst += bs_dist + 1; - } - } else { - /* they are the same. Since they can't co-occur, it means we - * encountered neither. 
*/
- src += backslash_and_quote::BYTES_PROCESSED;
- dst += backslash_and_quote::BYTES_PROCESSED;
- }
- }
- /* can't be reached */
- return nullptr;
-}
-
-simdjson_unused simdjson_warn_unused simdjson_really_inline error_code
-parse_string_to_buffer(const uint8_t *src,
- uint8_t *&current_string_buf_loc,
- std::string_view &s) {
- if (*(src++) != '"') {
- return STRING_ERROR;
- }
- auto end = stringparsing::parse_string(src, current_string_buf_loc);
- if (!end) {
- return STRING_ERROR;
- }
- s = std::string_view(reinterpret_cast<const char *>(current_string_buf_loc),
- end - current_string_buf_loc);
- current_string_buf_loc = end;
- return SUCCESS;
-}
-
-} // namespace stringparsing
-} // unnamed namespace
-} // namespace westmere
-} // namespace simdjson
-/* end file include/simdjson/generic/stringparsing.h */
-
-#endif // SIMDJSON_WESTMERE_STRINGPARSING_H
-/* end file include/simdjson/westmere/stringparsing.h */
-/* begin file include/simdjson/westmere/numberparsing.h */
-#ifndef SIMDJSON_WESTMERE_NUMBERPARSING_H
-#define SIMDJSON_WESTMERE_NUMBERPARSING_H
-
-namespace simdjson {
-namespace westmere {
-namespace {
-
-static simdjson_really_inline uint32_t
-parse_eight_digits_unrolled(const uint8_t *chars) {
- // this actually computes *16* values so we are being wasteful.
- const __m128i ascii0 = _mm_set1_epi8('0');
- const __m128i mul_1_10 =
- _mm_setr_epi8(10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1, 10, 1);
- const __m128i mul_1_100 = _mm_setr_epi16(100, 1, 100, 1, 100, 1, 100, 1);
- const __m128i mul_1_10000 =
- _mm_setr_epi16(10000, 1, 10000, 1, 10000, 1, 10000, 1);
- const __m128i input = _mm_sub_epi8(
- _mm_loadu_si128(reinterpret_cast<const __m128i *>(chars)), ascii0);
- const __m128i t1 = _mm_maddubs_epi16(input, mul_1_10);
- const __m128i t2 = _mm_madd_epi16(t1, mul_1_100);
- const __m128i t3 = _mm_packus_epi32(t2, t2);
- const __m128i t4 = _mm_madd_epi16(t3, mul_1_10000);
- return _mm_cvtsi128_si32(
- t4); // only captures the sum of the first 8 digits, drop the rest
-}
-
-} // unnamed namespace
-} // namespace westmere
-} // namespace simdjson
-
-#define SIMDJSON_SWAR_NUMBER_PARSING 1
-
-/* begin file include/simdjson/generic/numberparsing.h */
-#include <limits>
-
-namespace simdjson {
-namespace westmere {
-
-namespace ondemand {
-/**
- * The type of a JSON number
- */
-enum class number_type {
- floating_point_number = 1, /// a binary64 number
- signed_integer, /// a signed integer that fits in a 64-bit word using two's
- /// complement
- unsigned_integer /// a positive integer larger or equal to 1<<63
-};
-}
-
-namespace {
-/// @private
-namespace numberparsing {
-
-
-#ifdef JSON_TEST_NUMBERS
-#define INVALID_NUMBER(SRC) (found_invalid_number((SRC)), NUMBER_ERROR)
-#define WRITE_INTEGER(VALUE, SRC, WRITER) \
- (found_integer((VALUE), (SRC)), (WRITER).append_s64((VALUE)))
-#define WRITE_UNSIGNED(VALUE, SRC, WRITER) \
- (found_unsigned_integer((VALUE), (SRC)), (WRITER).append_u64((VALUE)))
-#define WRITE_DOUBLE(VALUE, SRC, WRITER) \
- (found_float((VALUE), (SRC)), (WRITER).append_double((VALUE)))
-#else
-#define INVALID_NUMBER(SRC) (NUMBER_ERROR)
-#define WRITE_INTEGER(VALUE, SRC, WRITER) (WRITER).append_s64((VALUE))
-#define WRITE_UNSIGNED(VALUE, SRC, WRITER) (WRITER).append_u64((VALUE))
-#define WRITE_DOUBLE(VALUE, SRC, WRITER) 
(WRITER).append_double((VALUE)) -#endif - -namespace { -// Convert a mantissa, an exponent and a sign bit into an ieee64 double. -// The real_exponent needs to be in [0, 2046] (technically real_exponent = 2047 -// would be acceptable). -// The mantissa should be in [0,1<<53). The bit at index (1ULL << 52) while be -// zeroed. -simdjson_really_inline double to_double(uint64_t mantissa, - uint64_t real_exponent, - bool negative) { - double d; - mantissa &= ~(1ULL << 52); - mantissa |= real_exponent << 52; - mantissa |= ((static_cast(negative)) << 63); - std::memcpy(&d, &mantissa, sizeof(d)); - return d; -} -} -// Attempts to compute i * 10^(power) exactly; and if "negative" is -// true, negate the result. -// This function will only work in some cases, when it does not work, success is -// set to false. This should work *most of the time* (like 99% of the time). -// We assume that power is in the [smallest_power, -// largest_power] interval: the caller is responsible for this check. -simdjson_really_inline bool compute_float_64(int64_t power, - uint64_t i, - bool negative, - double &d) { -// we start with a fast path -// It was described in -// Clinger WD. How to read floating point numbers accurately. -// ACM SIGPLAN Notices. 1990 -#ifndef FLT_EVAL_METHOD -#error "FLT_EVAL_METHOD should be defined, please include cfloat." -#endif -#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0) - // We cannot be certain that x/y is rounded to nearest. - if (0 <= power && power <= 22 && i <= 9007199254740991) { -#else - if (-22 <= power && power <= 22 && i <= 9007199254740991) { -#endif - // convert the integer into a double. This is lossless since - // 0 <= i <= 2^53 - 1. - d = double(i); - // - // The general idea is as follows. - // If 0 <= s < 2^53 and if 10^0 <= p <= 10^22 then - // 1) Both s and p can be represented exactly as 64-bit floating-point - // values - // (binary64). 
- // 2) Because s and p can be represented exactly as floating-point - // values, - // then s * p - // and s / p will produce correctly rounded values. - // - if (power < 0) { - d = d / simdjson::internal::power_of_ten[-power]; - } else { - d = d * simdjson::internal::power_of_ten[power]; - } - if (negative) { - d = -d; - } - return true; - } - // When 22 < power && power < 22 + 16, we could - // hope for another, secondary fast path. It was - // described by David M. Gay in "Correctly rounded - // binary-decimal and decimal-binary conversions." (1990) - // If you need to compute i * 10^(22 + x) for x < 16, - // first compute i * 10^x, if you know that result is exact - // (e.g., when i * 10^x < 2^53), - // then you can still proceed and do (i * 10^x) * 10^22. - // Is this worth your time? - // You need 22 < power *and* power < 22 + 16 *and* (i * 10^(x-22) < 2^53) - // for this second fast path to work. - // If you you have 22 < power *and* power < 22 + 16, and then you - // optimistically compute "i * 10^(x-22)", there is still a chance that you - // have wasted your time if i * 10^(x-22) >= 2^53. It makes the use cases of - // this optimization maybe less common than we would like. Source: - // http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ - // also used in RapidJSON: https://rapidjson.org/strtod_8h_source.html - - // The fast path has now failed, so we are failing back on the slower path. - - // In the slow path, we need to adjust i so that it is > 1<<63 which is - // always - // possible, except if i == 0, so we handle i == 0 separately. - if (i == 0) { - d = 0.0; - return true; - } - - - // The exponent is 1024 + 63 + power - // + floor(log(5**power)/log(2)). - // The 1024 comes from the ieee64 standard. - // The 63 comes from the fact that we use a 64-bit word. - // - // Computing floor(log(5**power)/log(2)) could be - // slow. Instead we use a fast function. 
- // - // For power in (-400,350), we have that - // (((152170 + 65536) * power ) >> 16); - // is equal to - // floor(log(5**power)/log(2)) + power when power >= 0 - // and it is equal to - // ceil(log(5**-power)/log(2)) + power when power < 0 - // - // The 65536 is (1<<16) and corresponds to - // (65536 * power) >> 16 ---> power - // - // ((152170 * power ) >> 16) is equal to - // floor(log(5**power)/log(2)) - // - // Note that this is not magic: 152170/(1<<16) is - // approximatively equal to log(5)/log(2). - // The 1<<16 value is a power of two; we could use a - // larger power of 2 if we wanted to. - // - int64_t exponent = (((152170 + 65536) * power) >> 16) + 1024 + 63; - - - // We want the most significant bit of i to be 1. Shift if needed. - int lz = leading_zeroes(i); - i <<= lz; - - - // We are going to need to do some 64-bit arithmetic to get a precise - // product. - // We use a table lookup approach. - // It is safe because - // power >= smallest_power - // and power <= largest_power - // We recover the mantissa of the power, it has a leading 1. It is always - // rounded down. - // - // We want the most significant 64 bits of the product. We know - // this will be non-zero because the most significant bit of i is - // 1. - const uint32_t index = - 2 * uint32_t(power - simdjson::internal::smallest_power); - // Optimization: It may be that materializing the index as a variable might - // confuse some compilers and prevent effective complex-addressing loads. - // (Done for code clarity.) - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high component" - // corresponding - // to the 64-bit most significant bits of the product. 
- simdjson::internal::value128 firstproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index]); - // Both i and power_of_five_128[index] have their most significant bit set - // to 1 which - // implies that the either the most or the second most significant bit of - // the product - // is 1. We pack values in this manner for efficiency reasons: it maximizes - // the use - // we make of the product. It also makes it easy to reason about the - // product: there - // is 0 or 1 leading zero in the product. - - // Unless the least significant 9 bits of the high (64-bit) part of the full - // product are all 1s, then we know that the most significant 55 bits are - // exact and no further work is needed. Having 55 bits is necessary because - // we need 53 bits for the mantissa but we have to have one rounding bit and - // we can waste a bit if the most significant bit of the product is zero. - if ((firstproduct.high & 0x1FF) == 0x1FF) { - // We want to compute i * 5^q, but only care about the top 55 bits at - // most. - // Consider the scenario where q>=0. Then 5^q may not fit in 64-bits. - // Doing - // the full computation is wasteful. So we do what is called a - // "truncated - // multiplication". - // We take the most significant 64-bits, and we put them in - // power_of_five_128[index]. Usually, that's good enough to approximate - // i * 5^q - // to the desired approximation using one multiplication. Sometimes it - // does not suffice. - // Then we store the next most significant 64 bits in - // power_of_five_128[index + 1], and - // then we get a better approximation to i * 5^q. In very rare cases, - // even that - // will not suffice, though it is seemingly very hard to find such a - // scenario. - // - // That's for when q>=0. The logic for q<0 is somewhat similar but it is - // somewhat - // more complicated. 
- // - // There is an extra layer of complexity in that we need more than 55 - // bits of - // accuracy in the round-to-even scenario. - // - // The full_multiplication function computes the 128-bit product of two - // 64-bit words - // with a returned value of type value128 with a "low component" - // corresponding to the - // 64-bit least significant bits of the product and with a "high - // component" corresponding - // to the 64-bit most significant bits of the product. - simdjson::internal::value128 secondproduct = - jsoncharutils::full_multiplication( - i, simdjson::internal::power_of_five_128[index + 1]); - firstproduct.low += secondproduct.high; - if (secondproduct.high > firstproduct.low) { - firstproduct.high++; - } - // At this point, we might need to add at most one to firstproduct, but - // this - // can only change the value of firstproduct.high if firstproduct.low is - // maximal. - if (simdjson_unlikely(firstproduct.low == 0xFFFFFFFFFFFFFFFF)) { - // This is very unlikely, but if so, we need to do much more work! - return false; - } - } - uint64_t lower = firstproduct.low; - uint64_t upper = firstproduct.high; - // The final mantissa should be 53 bits with a leading 1. - // We shift it so that it occupies 54 bits with a leading 1. - /////// - uint64_t upperbit = upper >> 63; - uint64_t mantissa = upper >> (upperbit + 9); - lz += int(1 ^ upperbit); - - // Here we have mantissa < (1<<54). - int64_t real_exponent = exponent - lz; - if (simdjson_unlikely(real_exponent <= 0)) { // we have a subnormal? - // Here have that real_exponent <= 0 so -real_exponent >= 0 - if (-real_exponent + 1 >= 64) { // if we have more than 64 bits below - // the minimum exponent, you have a - // zero for sure. - d = 0.0; - return true; - } - // next line is safe because -real_exponent + 1 < 0 - mantissa >>= -real_exponent + 1; - // Thankfully, we can't have both "round-to-even" and subnormals because - // "round-to-even" only occurs for powers close to 0. 
- mantissa += (mantissa & 1); // round up - mantissa >>= 1; - // There is a weird scenario where we don't have a subnormal but just. - // Suppose we start with 2.2250738585072013e-308, we end up - // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal - // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to - // round - // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer - // subnormal, but we can only know this after rounding. - // So we only declare a subnormal if we are smaller than the threshold. - real_exponent = (mantissa < (uint64_t(1) << 52)) ? 0 : 1; - d = to_double(mantissa, real_exponent, negative); - return true; - } - // We have to round to even. The "to even" part - // is only a problem when we are right in between two floats - // which we guard against. - // If we have lots of trailing zeros, we may fall right between two - // floating-point values. - // - // The round-to-even cases take the form of a number 2m+1 which is in - // (2^53,2^54] - // times a power of two. That is, it is right between a number with binary - // significand - // m and another number with binary significand m+1; and it must be the case - // that it cannot be represented by a float itself. - // - // We must have that w * 10 ^q == (2m+1) * 2^p for some power of two 2^p. - // Recall that 10^q = 5^q * 2^q. - // When q >= 0, we must have that (2m+1) is divible by 5^q, so 5^q <= 2^54. - // We have that - // 5^23 <= 2^54 and it is the last power of five to qualify, so q <= 23. - // When q<0, we have w >= (2m+1) x 5^{-q}. We must have that w<2^{64} so - // (2m+1) x 5^{-q} < 2^{64}. We have that 2m+1>2^{53}. Hence, we must have - // 2^{53} x 5^{-q} < 2^{64}. - // Hence we have 5^{-q} < 2^{11}$ or q>= -4. - // - // We require lower <= 1 and not lower == 0 because we could not prove that - // that lower == 0 is implied; but we could prove that lower <= 1 is a - // necessary and sufficient test. 
- if (simdjson_unlikely((lower <= 1) && (power >= -4) && (power <= 23) && - ((mantissa & 3) == 1))) { - if ((mantissa << (upperbit + 64 - 53 - 2)) == upper) { - mantissa &= ~1; // flip it so that we do not round up - } - } - - mantissa += mantissa & 1; - mantissa >>= 1; - - // Here we have mantissa < (1<<53), unless there was an overflow - if (mantissa >= (1ULL << 53)) { - ////////// - // This will happen when parsing values such as 7.2057594037927933e+16 - //////// - mantissa = (1ULL << 52); - real_exponent++; - } - mantissa &= ~(1ULL << 52); - // we have to check that real_exponent is in range, otherwise we bail out - if (simdjson_unlikely(real_exponent > 2046)) { - // We have an infinite value!!! We could actually throw an error here if - // we could. - return false; - } - d = to_double(mantissa, real_exponent, negative); - return true; -} - -// We call a fallback floating-point parser that might be slow. Note -// it will accept JSON numbers, but the JSON spec. is more restrictive so -// before you call parse_float_fallback, you need to have validated the input -// string with the JSON grammar. -// It will return an error (false) if the parsed number is infinite. -// The string parsing itself always succeeds. We know that there is at least -// one digit. -static bool parse_float_fallback(const uint8_t *ptr, double *outDouble) { - *outDouble = - simdjson::internal::from_chars(reinterpret_cast(ptr)); - // We do not accept infinite values. - - // Detecting finite values in a portable manner is ridiculously hard, - // ideally - // we would want to do: - // return !std::isfinite(*outDouble); - // but that mysteriously fails under legacy/old libc++ libraries, see - // https://github.com/simdjson/simdjson/issues/1286 - // - // Therefore, fall back to this solution (the extra parens are there - // to handle that max may be a macro on windows). 
- return !(*outDouble > (std::numeric_limits<double>::max)() ||
- *outDouble < std::numeric_limits<double>::lowest());
-}
-static bool parse_float_fallback(const uint8_t *ptr,
- const uint8_t *end_ptr,
- double *outDouble) {
- *outDouble =
- simdjson::internal::from_chars(reinterpret_cast<const char *>(ptr),
- reinterpret_cast<const char *>(end_ptr));
- // We do not accept infinite values.
-
- // Detecting finite values in a portable manner is ridiculously hard,
- // ideally
- // we would want to do:
- // return !std::isfinite(*outDouble);
- // but that mysteriously fails under legacy/old libc++ libraries, see
- // https://github.com/simdjson/simdjson/issues/1286
- //
- // Therefore, fall back to this solution (the extra parens are there
- // to handle that max may be a macro on windows).
- return !(*outDouble > (std::numeric_limits<double>::max)() ||
- *outDouble < std::numeric_limits<double>::lowest());
-}
-
-// check quickly whether the next 8 chars are made of digits
-// at a glance, it looks better than Mula's
-// http://0x80.pl/articles/swar-digits-validate.html
-simdjson_really_inline bool is_made_of_eight_digits_fast(const uint8_t *chars) {
- uint64_t val;
- // this can read up to 7 bytes beyond the buffer size, but we require
- // SIMDJSON_PADDING of padding
- static_assert(7 <= SIMDJSON_PADDING,
- "SIMDJSON_PADDING must be bigger than 7");
- std::memcpy(&val, chars, 8);
- // a branchy method might be faster:
- // return (( val & 0xF0F0F0F0F0F0F0F0 ) == 0x3030303030303030)
- // && (( (val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0 ) ==
- // 0x3030303030303030);
- return (((val & 0xF0F0F0F0F0F0F0F0) |
- (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) ==
- 0x3333333333333333);
-}
-
-template <typename W>
-error_code slow_float_parsing(simdjson_unused const uint8_t *src, W writer) {
- double d;
- if (parse_float_fallback(src, &d)) {
- writer.append_double(d);
- return SUCCESS;
- }
- return INVALID_NUMBER(src);
-}
-
-template <typename I>
-SIMDJSON_NO_SANITIZE_UNDEFINED // We deliberately allow overflow here and check
- // later
- 
simdjson_really_inline bool - parse_digit(const uint8_t c, I &i) { - const uint8_t digit = static_cast(c - '0'); - if (digit > 9) { - return false; - } - // PERF NOTE: multiplication by 10 is cheaper than arbitrary integer - // multiplication - i = 10 * i + digit; // might overflow, we will handle the overflow later - return true; -} - -simdjson_really_inline error_code -parse_decimal(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - uint64_t &i, - int64_t &exponent) { - // we continue with the fiction that we have an integer. If the - // floating point number is representable as x * 10^z for some integer - // z that fits in 53 bits, then we will be able to convert back the - // the integer into a float in a lossless manner. - const uint8_t *const first_after_period = p; - -#ifdef SIMDJSON_SWAR_NUMBER_PARSING -#if SIMDJSON_SWAR_NUMBER_PARSING - // this helps if we have lots of decimals! - // this turns out to be frequent enough. - if (is_made_of_eight_digits_fast(p)) { - i = i * 100000000 + parse_eight_digits_unrolled(p); - p += 8; - } -#endif // SIMDJSON_SWAR_NUMBER_PARSING -#endif // #ifdef SIMDJSON_SWAR_NUMBER_PARSING - // Unrolling the first digit makes a small difference on some - // implementations (e.g. westmere) - if (parse_digit(*p, i)) { - ++p; - } - while (parse_digit(*p, i)) { - p++; - } - exponent = first_after_period - p; - // Decimal without digits (123.) is illegal - if (exponent == 0) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -simdjson_really_inline error_code -parse_exponent(simdjson_unused const uint8_t *const src, - const uint8_t *&p, - int64_t &exponent) { - // Exp Sign: -123.456e[-]78 - bool neg_exp = ('-' == *p); - if (neg_exp || '+' == *p) { - p++; - } // Skip + as well - - // Exponent: -123.456e-[78] - auto start_exp = p; - int64_t exp_number = 0; - while (parse_digit(*p, exp_number)) { - ++p; - } - // It is possible for parse_digit to overflow. 
- // In particular, it could overflow to INT64_MIN, and we cannot do - - // INT64_MIN. - // Thus we *must* check for possible overflow before we negate exp_number. - - // Performance notes: it may seem like combining the two "simdjson_unlikely - // checks" below into - // a single simdjson_unlikely path would be faster. The reasoning is sound, - // but the compiler may - // not oblige and may, in fact, generate two distinct paths in any case. It - // might be - // possible to do uint64_t(p - start_exp - 1) >= 18 but it could end up - // trading off - // instructions for a simdjson_likely branch, an unconclusive gain. - - // If there were no digits, it's an error. - if (simdjson_unlikely(p == start_exp)) { - return INVALID_NUMBER(src); - } - // We have a valid positive exponent in exp_number at this point, except - // that - // it may have overflowed. - - // If there were more than 18 digits, we may have overflowed the integer. We - // have to do - // something!!!! - if (simdjson_unlikely(p > start_exp + 18)) { - // Skip leading zeroes: 1e000000000000000000001 is technically valid and - // doesn't overflow - while (*start_exp == '0') { - start_exp++; - } - // 19 digits could overflow int64_t and is kind of absurd anyway. We - // don't - // support exponents smaller than -999,999,999,999,999,999 and bigger - // than 999,999,999,999,999,999. - // We can truncate. - // Note that 999999999999999999 is assuredly too large. The maximal - // ieee64 value before - // infinity is ~1.8e308. The smallest subnormal is ~5e-324. So, - // actually, we could - // truncate at 324. - // Note that there is no reason to fail per se at this point in time. - // E.g., 0e999999999999999999999 is a fine number. - if (p > start_exp + 18) { - exp_number = 999999999999999999; - } - } - // At this point, we know that exp_number is a sane, positive, signed - // integer. - // It is <= 999,999,999,999,999,999. 
As long as 'exponent' is in - // [-8223372036854775808, 8223372036854775808], we won't overflow. Because - // 'exponent' - // is bounded in magnitude by the size of the JSON input, we are fine in - // this universe. - // To sum it up: the next line should never overflow. - exponent += (neg_exp ? -exp_number : exp_number); - return SUCCESS; -} - -simdjson_really_inline size_t significant_digits(const uint8_t *start_digits, - size_t digit_count) { - // It is possible that the integer had an overflow. - // We have to handle the case where we have 0.0000somenumber. - const uint8_t *start = start_digits; - while ((*start == '0') || (*start == '.')) { - ++start; - } - // we over-decrement by one when there is a '.' - return digit_count - size_t(start - start_digits); -} - -template -simdjson_really_inline error_code write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer) { - // If we frequently had to deal with long strings of digits, - // we could extend our code by using a 128-bit integer instead - // of a 64-bit integer. However, this is uncommon in practice. - // - // 9999999999999999999 < 2**64 so we can accommodate 19 digits. - // If we have a decimal separator, then digit_count - 1 is the number of - // digits, but we - // may not have a decimal separator! - if (simdjson_unlikely(digit_count > 19 && - significant_digits(start_digits, digit_count) > 19)) { - // Ok, chances are good that we had an overflow! - // this is almost never going to get called!!! - // we start anew, going slowly!!! - // This will happen in the following examples: - // 10000000000000000000000000000000000000000000e+308 - // 3.1415926535897932384626433832795028841971693993751 - // - // NOTE: This makes a *copy* of the writer and passes it to - // slow_float_parsing. This happens - // because slow_float_parsing is a non-inlined function. 
If we passed - // our writer reference to - // it, it would force it to be stored in memory, preventing the compiler - // from picking it apart - // and putting into registers. i.e. if we pass it as reference, it gets - // slow. - // This is what forces the skip_double, as well. - error_code error = slow_float_parsing(src, writer); - writer.skip_double(); - return error; - } - // NOTE: it's weird that the simdjson_unlikely() only wraps half the if, but - // it seems to get slower any other - // way we've tried: - // https://github.com/simdjson/simdjson/pull/990#discussion_r448497331 - // To future reader: we'd love if someone found a better way, or at least - // could explain this result! - if (simdjson_unlikely(exponent < simdjson::internal::smallest_power) || - (exponent > simdjson::internal::largest_power)) { - // - // Important: smallest_power is such that it leads to a zero value. - // Observe that 18446744073709551615e-343 == 0, i.e. (2**64 - 1) e -343 - // is zero - // so something x 10^-343 goes to zero, but not so with something x - // 10^-342. - static_assert(simdjson::internal::smallest_power <= -342, - "smallest_power is not small enough"); - // - if ((exponent < simdjson::internal::smallest_power) || (i == 0)) { - WRITE_DOUBLE(0, src, writer); - return SUCCESS; - } else { // (exponent > largest_power) and (i != 0) - // We have, for sure, an infinite value and simdjson refuses to - // parse infinite values. - return INVALID_NUMBER(src); - } - } - double d; - if (!compute_float_64(exponent, i, negative, d)) { - // we are almost never going to get here. 
- if (!parse_float_fallback(src, &d)) { - return INVALID_NUMBER(src); - } - } - WRITE_DOUBLE(d, src, writer); - return SUCCESS; -} - -// for performance analysis, it is sometimes useful to skip parsing -#ifdef SIMDJSON_SKIPNUMBERPARSING - -template -simdjson_really_inline error_code parse_number(const uint8_t *const, - W &writer) { - writer.append_s64(0); // always write zero - return SUCCESS; // always succeeds -} - -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *const src) noexcept { - return 0; -} -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - return false; -} -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - return ondemand::number_type::signed_integer; -} -#else - -// parse the number at src -// define JSON_TEST_NUMBERS for unit testing -// -// It is assumed that the number is followed by a structural ({,},],[) character -// or a white space character. If that is not the case (e.g., when the JSON -// document is made of a single number), then it is necessary to copy the -// content and append a space before calling this function. 
-// -// Our objective is accurate parsing (ULP of 0) at high speed. -template -simdjson_really_inline error_code parse_number(const uint8_t *const src, - W &writer) { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - if (digit_count == 0 || ('0' == *start_digits && digit_count > 1)) { - return INVALID_NUMBER(src); - } - - // - // Handle floats if there is a . or e (or both) - // - int64_t exponent = 0; - bool is_float = false; - if ('.' == *p) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_decimal(src, p, i, exponent)); - digit_count = - int(p - start_digits); // used later to guard against overflows - } - if (('e' == *p) || ('E' == *p)) { - is_float = true; - ++p; - SIMDJSON_TRY(parse_exponent(src, p, exponent)); - } - if (is_float) { - const bool dirty_end = - jsoncharutils::is_not_structural_or_whitespace(*p); - SIMDJSON_TRY(write_float( - src, negative, i, start_digits, digit_count, exponent, writer)); - if (dirty_end) { - return INVALID_NUMBER(src); - } - return SUCCESS; - } - - // The longest negative 64-bit number is 19 digits. - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - size_t longest_digit_count = negative ? 
19 : 20; - if (digit_count > longest_digit_count) { - return INVALID_NUMBER(src); - } - if (digit_count == longest_digit_count) { - if (negative) { - // Anything negative above INT64_MAX+1 is invalid - if (i > uint64_t(INT64_MAX) + 1) { - return INVALID_NUMBER(src); - } - WRITE_INTEGER(~i + 1, src, writer); - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less - // than INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit - // "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), - // the result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the - // user could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - } else if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INVALID_NUMBER(src); - } - } - - // Write unsigned if it doesn't fit in a signed integer. - if (i > uint64_t(INT64_MAX)) { - WRITE_UNSIGNED(i, src, writer); - } else { - WRITE_INTEGER(negative ? (~i + 1) : i, src, writer); - } - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return INVALID_NUMBER(src); - } - return SUCCESS; -} - -// Inlineable functions -namespace { - -// This table can be used to characterize the final character of an integer -// string. For JSON structural character and allowable white space characters, -// we return SUCCESS. For 'e', '.' and 'E', we return INCORRECT_TYPE. Otherwise -// we return NUMBER_ERROR. 
-// Optimization note: we could easily reduce the size of the table by half (to -// 128) -// at the cost of an extra branch. -// Optimization note: we want the values to use at most 8 bits (not, e.g., 32 -// bits): -static_assert(error_code(uint8_t(NUMBER_ERROR)) == NUMBER_ERROR, - "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(SUCCESS)) == SUCCESS, "bad NUMBER_ERROR cast"); -static_assert(error_code(uint8_t(INCORRECT_TYPE)) == INCORRECT_TYPE, - "bad NUMBER_ERROR cast"); - -const uint8_t integer_string_finisher[256] = { - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, INCORRECT_TYPE, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, SUCCESS, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, INCORRECT_TYPE, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, SUCCESS, NUMBER_ERROR, - SUCCESS, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, 
NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, NUMBER_ERROR, - NUMBER_ERROR}; - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. 
- // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - - -// Parse any number from 0 to 18,446,744,073,709,551,615 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_unsigned( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - const uint8_t *p = src; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - if (src[0] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from 0 to 18,446,744,073,709,551,615 -simdjson_unused simdjson_really_inline simdjson_result -parse_unsigned_in_string(const uint8_t *const src) noexcept { - const uint8_t *p = src + 1; - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // The longest positive 64-bit number is 20 digits. - // We do it this way so we don't trigger this branch unless we must. 
- // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > 20)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > 20)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - - if (digit_count == 20) { - // Positive overflow check: - // - A 20 digit number starting with 2-9 is overflow, because - // 18,446,744,073,709,551,615 is the - // biggest uint64_t. - // - A 20 digit number starting with 1 is overflow if it is less than - // INT64_MAX. - // If we got here, it's a 20 digit number starting with the digit "1". - // - If a 20 digit number starting with 1 overflowed (i*10+digit), the - // result will be smaller - // than 1,553,255,926,290,448,384. - // - That is smaller than the smallest possible 20-digit number the user - // could write: - // 10,000,000,000,000,000,000. - // - Therefore, if the number is positive and lower than that, it's - // overflow. - // - The value we are looking at is less than or equal to INT64_MAX. - // - // Note: we use src[1] and not src[0] because src[0] is the quote - // character in this - // instance. - if (src[1] != uint8_t('1') || i <= uint64_t(INT64_MAX)) { - return INCORRECT_TYPE; - } - } - - return i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. 
- // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? 
(~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_integer( - const uint8_t *const src, const uint8_t *const src_end) noexcept { - // - // Check for minus sign - // - if (src == src_end) { - return NUMBER_ERROR; - } - bool negative = (*src == '-'); - const uint8_t *p = src + negative; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if ((p != src_end) && integer_string_finisher[*p] != SUCCESS) { - return error_code(integer_string_finisher[*p]); - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. 
- // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -// Parse any number from -9,223,372,036,854,775,808 to -// 9,223,372,036,854,775,807 -simdjson_unused simdjson_really_inline simdjson_result -parse_integer_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - const uint8_t *p = src + negative + 1; - - // - // Parse the integer part. - // - // PERF NOTE: we don't use is_made_of_eight_digits_fast because large - // integers like 123456789 are rare - const uint8_t *const start_digits = p; - uint64_t i = 0; - while (parse_digit(*p, i)) { - p++; - } - - // If there were no digits, or if the integer starts with 0 and has more - // than one digit, it's an error. - // Optimization note: size_t is expected to be unsigned. - size_t digit_count = size_t(p - start_digits); - // We go from - // -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 - // so we can never represent numbers that have more than 19 digits. - size_t longest_digit_count = 19; - // Optimization note: the compiler can probably merge - // ((digit_count == 0) || (digit_count > longest_digit_count)) - // into a single branch since digit_count is unsigned. - if ((digit_count == 0) || (digit_count > longest_digit_count)) { - return INCORRECT_TYPE; - } - // Here digit_count > 0. - if (('0' == *start_digits) && (digit_count > 1)) { - return NUMBER_ERROR; - } - // We can do the following... - // if (!jsoncharutils::is_structural_or_whitespace(*p)) { - // return (*p == '.' || *p == 'e' || *p == 'E') ? 
INCORRECT_TYPE : - // NUMBER_ERROR; - // } - // as a single table lookup: - if (*p != '"') { - return NUMBER_ERROR; - } - // Negative numbers have can go down to - INT64_MAX - 1 whereas positive - // numbers are limited to INT64_MAX. - // Performance note: This check is only needed when digit_count == - // longest_digit_count but it is - // so cheap that we might as well always make it. - if (i > uint64_t(INT64_MAX) + uint64_t(negative)) { - return INCORRECT_TYPE; - } - return negative ? (~i + 1) : i; -} - -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if (jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline bool is_negative( - const uint8_t *src) noexcept { - return (*src == '-'); -} - -simdjson_unused simdjson_really_inline simdjson_result is_integer( - const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if (jsoncharutils::is_structural_or_whitespace(*p)) { - return true; - } - return false; -} - -simdjson_unused simdjson_really_inline simdjson_result -get_number_type(const uint8_t *src) noexcept { - bool negative = (*src == '-'); - src += negative; - const uint8_t *p = src; - while (static_cast(*p - '0') <= 9) { - p++; - } - if (p == src) { - return NUMBER_ERROR; - } - if 
(jsoncharutils::is_structural_or_whitespace(*p)) { - int digit_count = int(p - src); - if (digit_count >= 19) { - const uint8_t *smaller_big_integer = - reinterpret_cast("9223372036854775808"); - if ((digit_count >= 20) || - (memcmp(src, smaller_big_integer, 19) >= 0)) { - return ondemand::number_type::unsigned_integer; - } - } - return ondemand::number_type::signed_integer; - } - return ondemand::number_type::floating_point_number; -} - -// Never read at src_end or beyond -simdjson_unused simdjson_really_inline simdjson_result parse_double( - const uint8_t *src, const uint8_t *const src_end) noexcept { - if (src == src_end) { - return NUMBER_ERROR; - } - // - // Check for minus sign - // - bool negative = (*src == '-'); - src += negative; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - if (p == src_end) { - return NUMBER_ERROR; - } - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. - // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely((p != src_end) && (*p == '.'))) { - p++; - const uint8_t *start_decimal_digits = p; - if ((p == src_end) || !parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while ((p != src_end) && parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. 
- overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if ((p != src_end) && (*p == 'e' || *p == 'E')) { - p++; - if (p == src_end) { - return NUMBER_ERROR; - } - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while ((p != src_end) && parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 0 - exp : exp; - } - - if ((p != src_end) && jsoncharutils::is_not_structural_or_whitespace(*p)) { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, src_end, &d)) { - return NUMBER_ERROR; - } - return d; -} - -simdjson_unused simdjson_really_inline simdjson_result -parse_double_in_string(const uint8_t *src) noexcept { - // - // Check for minus sign - // - bool negative = (*(src + 1) == '-'); - src += negative + 1; - - // - // Parse the integer part. - // - uint64_t i = 0; - const uint8_t *p = src; - p += parse_digit(*p, i); - bool leading_zero = (i == 0); - while (parse_digit(*p, i)) { - p++; - } - // no integer digits, or 0123 (zero must be solo) - if (p == src) { - return INCORRECT_TYPE; - } - if ((leading_zero && p != src + 1)) { - return NUMBER_ERROR; - } - - // - // Parse the decimal part. 
- // - int64_t exponent = 0; - bool overflow; - if (simdjson_likely(*p == '.')) { - p++; - const uint8_t *start_decimal_digits = p; - if (!parse_digit(*p, i)) { - return NUMBER_ERROR; - } // no decimal digits - p++; - while (parse_digit(*p, i)) { - p++; - } - exponent = -(p - start_decimal_digits); - - // Overflow check. More than 19 digits (minus the decimal) may be - // overflow. - overflow = p - src - 1 > 19; - if (simdjson_unlikely(overflow && leading_zero)) { - // Skip leading 0.00000 and see if it still overflows - const uint8_t *start_digits = src + 2; - while (*start_digits == '0') { - start_digits++; - } - overflow = start_digits - src > 19; - } - } else { - overflow = p - src > 19; - } - - // - // Parse the exponent - // - if (*p == 'e' || *p == 'E') { - p++; - bool exp_neg = *p == '-'; - p += exp_neg || *p == '+'; - - uint64_t exp = 0; - const uint8_t *start_exp_digits = p; - while (parse_digit(*p, exp)) { - p++; - } - // no exp digits, or 20+ exp digits - if (p - start_exp_digits == 0 || p - start_exp_digits > 19) { - return NUMBER_ERROR; - } - - exponent += exp_neg ? 
0 - exp : exp; - } - - if (*p != '"') { - return NUMBER_ERROR; - } - - overflow = overflow || exponent < simdjson::internal::smallest_power || - exponent > simdjson::internal::largest_power; - - // - // Assemble (or slow-parse) the float - // - double d; - if (simdjson_likely(!overflow)) { - if (compute_float_64(exponent, i, negative, d)) { - return d; - } - } - if (!parse_float_fallback(src - negative, &d)) { - return NUMBER_ERROR; - } - return d; -} -} // namespace {} -#endif // SIMDJSON_SKIPNUMBERPARSING - -} // namespace numberparsing -} // unnamed namespace -} // namespace westmere -} // namespace simdjson -/* end file include/simdjson/generic/numberparsing.h */ - -#endif // SIMDJSON_WESTMERE_NUMBERPARSING_H -/* end file include/simdjson/westmere/numberparsing.h */ -/* begin file include/simdjson/westmere/end.h */ -SIMDJSON_UNTARGET_WESTMERE -/* end file include/simdjson/westmere/end.h */ - -#endif // SIMDJSON_IMPLEMENTATION_WESTMERE -#endif // SIMDJSON_WESTMERE_COMMON_H -/* end file include/simdjson/westmere.h */ - -// Builtin implementation - -SIMDJSON_POP_DISABLE_WARNINGS - -#endif // SIMDJSON_IMPLEMENTATIONS_H -/* end file include/simdjson/implementations.h */ - -// Determine the best builtin implementation -#ifndef SIMDJSON_BUILTIN_IMPLEMENTATION -#if SIMDJSON_CAN_ALWAYS_RUN_HASWELL -#define SIMDJSON_BUILTIN_IMPLEMENTATION haswell -#elif SIMDJSON_CAN_ALWAYS_RUN_WESTMERE -#define SIMDJSON_BUILTIN_IMPLEMENTATION westmere -#elif SIMDJSON_CAN_ALWAYS_RUN_ARM64 -#define SIMDJSON_BUILTIN_IMPLEMENTATION arm64 -#elif SIMDJSON_CAN_ALWAYS_RUN_PPC64 -#define SIMDJSON_BUILTIN_IMPLEMENTATION ppc64 -#elif SIMDJSON_CAN_ALWAYS_RUN_FALLBACK -#define SIMDJSON_BUILTIN_IMPLEMENTATION fallback -#else -#error \ - "All possible implementations (including fallback) have been disabled! simdjson will not run." 
-#endif -#endif // SIMDJSON_BUILTIN_IMPLEMENTATION - -// redefining SIMDJSON_IMPLEMENTATION to "SIMDJSON_BUILTIN_IMPLEMENTATION" -// #define SIMDJSON_IMPLEMENTATION SIMDJSON_BUILTIN_IMPLEMENTATION - -// ondemand is only compiled as part of the builtin implementation at present - -// Interface declarations -/* begin file include/simdjson/generic/implementation_simdjson_result_base.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { - -// This is a near copy of include/error.h's implementation_simdjson_result_base, -// except it doesn't use std::pair -// so we can avoid inlining errors -// TODO reconcile these! -/** - * The result of a simdjson operation that could fail. - * - * Gives the option of reading error codes, or throwing an exception by casting - * to the desired result. - * - * This is a base class for implementations that want to add functions to the - * result type for - * chaining. - * - * Override like: - * - * struct simdjson_result : public - * internal::implementation_simdjson_result_base { - * simdjson_result() noexcept : - * internal::implementation_simdjson_result_base() {} - * simdjson_result(error_code error) noexcept : - * internal::implementation_simdjson_result_base(error) {} - * simdjson_result(T &&value) noexcept : - * internal::implementation_simdjson_result_base(std::forward(value)) {} - * simdjson_result(T &&value, error_code error) noexcept : - * internal::implementation_simdjson_result_base(value, error) {} - * // Your extra methods here - * } - * - * Then any method returning simdjson_result will be chainable with your - * methods. - */ -template -struct implementation_simdjson_result_base { - /** - * Create a new empty result with error = UNINITIALIZED. - */ - simdjson_really_inline implementation_simdjson_result_base() noexcept = - default; - - /** - * Create a new error result. 
- */ - simdjson_really_inline implementation_simdjson_result_base( - error_code error) noexcept; - - /** - * Create a new successful result. - */ - simdjson_really_inline implementation_simdjson_result_base( - T &&value) noexcept; - - /** - * Create a new result with both things (use if you don't want to branch - * when creating the result). - */ - simdjson_really_inline implementation_simdjson_result_base( - T &&value, error_code error) noexcept; - - /** - * Move the value and the error to the provided variables. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - * @param error The variable to assign the error to. Set to SUCCESS if there - * is no error. - */ - simdjson_really_inline void tie(T &value, error_code &error) && noexcept; - - /** - * Move the value to the provided variable. - * - * @param value The variable to assign the value to. May not be set if there - * is an error. - */ - simdjson_really_inline error_code get(T &value) && noexcept; - - /** - * The error. - */ - simdjson_really_inline error_code error() const noexcept; - -#if SIMDJSON_EXCEPTIONS - - /** - * Get the result value. - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &value() & noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &&value() && noexcept(false); - - /** - * Take the result value (move it). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline T &&take_value() && noexcept(false); - - /** - * Cast to the value (will throw on error). - * - * @throw simdjson_error if there was an error. - */ - simdjson_really_inline operator T &&() && noexcept(false); - - -#endif // SIMDJSON_EXCEPTIONS - - /** - * Get the result value. This function is safe if and only - * the error() method returns a value that evaluates to false. 
- */ - simdjson_really_inline const T &value_unsafe() const &noexcept; - /** - * Get the result value. This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline T &value_unsafe() & noexcept; - /** - * Take the result value (move it). This function is safe if and only - * the error() method returns a value that evaluates to false. - */ - simdjson_really_inline T &&value_unsafe() && noexcept; - - protected: - /** users should never directly access first and second. **/ - T first{}; /** Users should never directly access 'first'. **/ - error_code second{ - UNINITIALIZED}; /** Users should never directly access 'second'. **/ -}; // struct implementation_simdjson_result_base - -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson -/* end file include/simdjson/generic/implementation_simdjson_result_base.h */ -/* begin file include/simdjson/generic/ondemand.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -/** - * A fast, simple, DOM-like interface that parses JSON as you use it. - * - * Designed for maximum speed and a lower memory profile. - */ -namespace ondemand { - -/** Represents the depth of a JSON value (number of nested arrays/objects). */ -using depth_t = int32_t; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -/* begin file include/simdjson/generic/ondemand/json_type.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { -/** - * The type of a JSON value. - */ -enum class json_type { - // Start at 1 to catch uninitialized / default values more easily - array = 1, ///< A JSON array ( [ 1, 2, 3 ... ] ) - object, ///< A JSON object ( { "a": 1, "b" 2, ... } ) - number, ///< A JSON number ( 1 or -2.3 or 4.5e6 ...) - string, ///< A JSON string ( "a" or "hello world\n" ...) 
- boolean, ///< A JSON boolean (true or false) - null ///< A JSON null (null) -}; - -class value_iterator; - -/** - * A type representing a JSON number. - * The design of the struct is deliberately straight-forward. All - * functions return standard values with no error check. - */ -struct number { - /** - * return the automatically determined type of - * the number: number_type::floating_point_number, - * number_type::signed_integer or number_type::unsigned_integer. - * - * enum class number_type { - * floating_point_number=1, /// a binary64 number - * signed_integer, /// a signed integer that fits in a - * 64-bit word using two's complement - * unsigned_integer /// a positive integer larger or equal to - * 1<<63 - * }; - */ - simdjson_really_inline number_type get_number_type() const noexcept; - /** - * return true if the automatically determined type of - * the number is number_type::unsigned_integer. - */ - simdjson_really_inline bool is_uint64() const noexcept; - /** - * return the value as a uint64_t, only valid if is_uint64() is true. - */ - simdjson_really_inline uint64_t get_uint64() const noexcept; - simdjson_really_inline operator uint64_t() const noexcept; - - /** - * return true if the automatically determined type of - * the number is number_type::signed_integer. - */ - simdjson_really_inline bool is_int64() const noexcept; - /** - * return the value as a int64_t, only valid if is_int64() is true. - */ - simdjson_really_inline int64_t get_int64() const noexcept; - simdjson_really_inline operator int64_t() const noexcept; - - - /** - * return true if the automatically determined type of - * the number is number_type::floating_point_number. - */ - simdjson_really_inline bool is_double() const noexcept; - /** - * return the value as a double, only valid if is_double() is true. - */ - simdjson_really_inline double get_double() const noexcept; - simdjson_really_inline operator double() const noexcept; - - /** - * Convert the number to a double. 
Though it always succeed, the conversion - * may be lossy if the number cannot be represented exactly. - */ - simdjson_really_inline double as_double() const noexcept; - - - protected: - /** - * The next block of declaration is designed so that we can call the number - * parsing - * functions on a number type. They are protected and should never be used - * outside - * of the core simdjson library. - */ - friend class value_iterator; - template - friend error_code numberparsing::write_float(const uint8_t *const src, - bool negative, - uint64_t i, - const uint8_t *start_digits, - size_t digit_count, - int64_t exponent, - W &writer); - template - friend error_code numberparsing::parse_number(const uint8_t *const src, - W &writer); - template - friend error_code numberparsing::slow_float_parsing( - simdjson_unused const uint8_t *src, W writer); - /** Store a signed 64-bit value to the number. */ - simdjson_really_inline void append_s64(int64_t value) noexcept; - /** Store an unsigned 64-bit value to the number. */ - simdjson_really_inline void append_u64(uint64_t value) noexcept; - /** Store a double value to the number. */ - simdjson_really_inline void append_double(double value) noexcept; - /** Specifies that the value is a double, but leave it undefined. */ - simdjson_really_inline void skip_double() noexcept; - /** - * End of friend declarations. - */ - - /** - * Our attributes are a union type (size = 64 bits) - * followed by a type indicator. - */ - union { - double floating_point_number; - int64_t signed_integer; - uint64_t unsigned_integer; - } payload{0}; - number_type type{number_type::signed_integer}; -}; - -/** - * Write the JSON type to the output stream - * - * @param out The output stream. - * @param type The json_type. - */ -inline std::ostream &operator<<(std::ostream &out, json_type type) noexcept; -inline std::ostream &operator<<(std::ostream &out, number_type type) noexcept; - -#if SIMDJSON_EXCEPTIONS -/** - * Send JSON type to an output stream. 
- * - * @param out The output stream. - * @param type The json_type. - * @throw simdjson_error if the result being printed has an error. If there is - * an error with the - * underlying output stream, that error will be propagated - * (simdjson_error will not be - * thrown). - */ -inline std::ostream &operator<<( - std::ostream &out, simdjson_result &type) noexcept(false); -#endif - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - simdjson_really_inline ~simdjson_result() noexcept = default; ///< @private -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/json_type.h */ -/* begin file include/simdjson/generic/ondemand/token_position.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -/** @private Position in the JSON buffer indexes */ -using token_position = const uint32_t *; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/token_position.h */ -/* begin file include/simdjson/generic/ondemand/logger.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class json_iterator; -class value_iterator; - -namespace logger { - -#if SIMDJSON_VERBOSE_LOGGING -static constexpr const bool LOG_ENABLED = true; -#else -static constexpr const bool LOG_ENABLED = false; -#endif - -// We do not want these functions to be 'really inlined' 
since real inlining is -// for performance purposes and if you are using the loggers, you do not care -// about -// performance (or should not). -static inline void log_headers() noexcept; -static inline void log_line(const json_iterator &iter, - token_position index, - depth_t depth, - const char *title_prefix, - const char *title, - std::string_view detail) noexcept; -static inline void log_line(const json_iterator &iter, - const char *title_prefix, - const char *title, - std::string_view detail, - int delta, - int depth_delta) noexcept; -static inline void log_event(const json_iterator &iter, - const char *type, - std::string_view detail = "", - int delta = 0, - int depth_delta = 0) noexcept; -static inline void log_value(const json_iterator &iter, - token_position index, - depth_t depth, - const char *type, - std::string_view detail = "") noexcept; -static inline void log_value(const json_iterator &iter, - const char *type, - std::string_view detail = "", - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_start_value(const json_iterator &iter, - token_position index, - depth_t depth, - const char *type, - std::string_view detail = "") noexcept; -static inline void log_start_value(const json_iterator &iter, - const char *type, - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_end_value(const json_iterator &iter, - const char *type, - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_error(const json_iterator &iter, - token_position index, - depth_t depth, - const char *error, - const char *detail = "") noexcept; -static inline void log_error(const json_iterator &iter, - const char *error, - const char *detail = "", - int delta = -1, - int depth_delta = 0) noexcept; - -static inline void log_event(const value_iterator &iter, - const char *type, - std::string_view detail = "", - int delta = 0, - int depth_delta = 0) noexcept; -static inline void log_value(const value_iterator &iter, - const 
char *type, - std::string_view detail = "", - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_start_value(const value_iterator &iter, - const char *type, - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_end_value(const value_iterator &iter, - const char *type, - int delta = -1, - int depth_delta = 0) noexcept; -static inline void log_error(const value_iterator &iter, - const char *error, - const char *detail = "", - int delta = -1, - int depth_delta = 0) noexcept; - -} // namespace logger -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/logger.h */ -/* begin file include/simdjson/generic/ondemand/raw_json_string.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class object; -class parser; -class json_iterator; - -/** - * A string escaped per JSON rules, terminated with quote ("). They are used to - * represent - * unescaped keys inside JSON documents. - * - * (In other words, a pointer to the beginning of a string, just after the start - * quote, inside a - * JSON file.) - * - * This class is deliberately simplistic and has little functionality. You can - * compare a raw_json_string instance with an unescaped C string, but - * that is pretty much all you can do. - * - * They originate typically from field instance which in turn represent - * key-value pairs from - * object instances. From a field instance, you get the raw_json_string instance - * by calling key(). - * You can, if you want a more usable string_view instance, call the - * unescaped_key() method - * on the field instance. - */ -class raw_json_string { - public: - /** - * Create a new invalid raw_json_string. - * - * Exists so you can declare a variable and later assign to it before use. 
- */ - simdjson_really_inline raw_json_string() noexcept = default; - - /** - * Create a new invalid raw_json_string pointed at the given location in the - * JSON. - * - * The given location must be just *after* the beginning quote (") in the - * JSON file. - * - * It *must* be terminated by a ", and be a valid JSON string. - */ - simdjson_really_inline raw_json_string(const uint8_t *_buf) noexcept; - /** - * Get the raw pointer to the beginning of the string in the JSON (just - * after the "). - * - * It is possible for this function to return a null pointer if the instance - * has outlived its existence. - */ - simdjson_really_inline const char *raw() const noexcept; - - /** - * This compares the current instance to the std::string_view target: - * returns true if - * they are byte-by-byte equal (no escaping is done) on target.size() - * characters, - * and if the raw_json_string instance has a quote character at byte index - * target.size(). - * We never read more than length + 1 bytes in the raw_json_string instance. - * If length is smaller than target.size(), this will return false. - * - * The std::string_view instance may contain any characters. However, the - * caller - * is responsible for setting length so that length bytes may be read in the - * raw_json_string. - * - * Performance: the comparison may be done using memcmp which may be - * efficient - * for long strings. - */ - simdjson_really_inline bool unsafe_is_equal(size_t length, - std::string_view target) const - noexcept; - - /** - * This compares the current instance to the std::string_view target: - * returns true if - * they are byte-by-byte equal (no escaping is done). - * The std::string_view instance should not contain unescaped quote - * characters: - * the caller is responsible for this check. See - * is_free_from_unescaped_quote. - * - * Performance: the comparison is done byte-by-byte which might be - * inefficient for - * long strings. 
- * - * If target is a compile-time constant, and your compiler likes you, - * you should be able to do the following without performance penalty... - * - * static_assert(raw_json_string::is_free_from_unescaped_quote(target), - * ""); - * s.unsafe_is_equal(target); - */ - simdjson_really_inline bool unsafe_is_equal(std::string_view target) const - noexcept; - - /** - * This compares the current instance to the C string target: returns true - * if - * they are byte-by-byte equal (no escaping is done). - * The provided C string should not contain an unescaped quote character: - * the caller is responsible for this check. See - * is_free_from_unescaped_quote. - * - * If target is a compile-time constant, and your compiler likes you, - * you should be able to do the following without performance penalty... - * - * static_assert(raw_json_string::is_free_from_unescaped_quote(target), - * ""); - * s.unsafe_is_equal(target); - */ - simdjson_really_inline bool unsafe_is_equal(const char *target) const - noexcept; - - /** - * This compares the current instance to the std::string_view target: - * returns true if - * they are byte-by-byte equal (no escaping is done). - */ - simdjson_really_inline bool is_equal(std::string_view target) const - noexcept; - - /** - * This compares the current instance to the C string target: returns true - * if - * they are byte-by-byte equal (no escaping is done). - */ - simdjson_really_inline bool is_equal(const char *target) const noexcept; - - /** - * Returns true if target is free from unescaped quote. If target is known - * at - * compile-time, we might expect the computation to happen at compile time - * with - * many compilers (not all!). 
- */ - static simdjson_really_inline bool is_free_from_unescaped_quote( - std::string_view target) noexcept; - static simdjson_really_inline bool is_free_from_unescaped_quote( - const char *target) noexcept; - - private: - /** - * This will set the inner pointer to zero, effectively making - * this instance unusable. - */ - simdjson_really_inline void consume() noexcept { buf = nullptr; } - - /** - * Checks whether the inner pointer is non-null and thus usable. - */ - simdjson_really_inline simdjson_warn_unused bool alive() const noexcept { - return buf != nullptr; - } - - /** - * Unescape this JSON string, replacing \\ with \, \n with newline, etc. - * - * ## IMPORTANT: string_view lifetime - * - * The string_view is only valid as long as the bytes in dst. - * - * @param dst A pointer to a buffer at least large enough to write this - * string as well as a \0. - * dst will be updated to the next unused location (just after - * the \0 written out at - * the end of this string). - * @return A string_view pointing at the unescaped string in dst - * @error STRING_ERROR if escapes are incorrect. - */ - simdjson_really_inline simdjson_warn_unused - simdjson_result - unescape(uint8_t *&dst) const noexcept; - /** - * Unescape this JSON string, replacing \\ with \, \n with newline, etc. - * - * ## IMPORTANT: string_view lifetime - * - * The string_view is only valid until the next parse() call on the parser. - * - * @param iter A json_iterator, which contains a buffer where the string - * will be written. 
- */ - simdjson_really_inline simdjson_warn_unused - simdjson_result - unescape(json_iterator &iter) const noexcept; - - const uint8_t *buf{}; - friend class object; - friend class field; - friend struct simdjson_result; -}; - -simdjson_unused simdjson_really_inline std::ostream &operator<<( - std::ostream &, const raw_json_string &) noexcept; - -/** - * Comparisons between raw_json_string and std::string_view instances are - * potentially unsafe: the user is responsible - * for providing a string with no unescaped quote. Note that unescaped quotes - * cannot be present in valid JSON strings. - */ -simdjson_unused simdjson_really_inline bool operator==( - const raw_json_string &a, std::string_view c) noexcept; -simdjson_unused simdjson_really_inline bool operator==( - std::string_view c, const raw_json_string &a) noexcept; -simdjson_unused simdjson_really_inline bool operator!=( - const raw_json_string &a, std::string_view c) noexcept; -simdjson_unused simdjson_really_inline bool operator!=( - std::string_view c, const raw_json_string &a) noexcept; - - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - simdjson_really_inline ~simdjson_result() noexcept = default; ///< @private - - simdjson_really_inline simdjson_result raw() const noexcept; - simdjson_really_inline simdjson_warn_unused - simdjson_result - unescape(uint8_t *&dst) const noexcept; - 
simdjson_really_inline simdjson_warn_unused - simdjson_result - unescape(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator &iter) - const noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/raw_json_string.h */ -/* begin file include/simdjson/generic/ondemand/token_iterator.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -/** - * Iterates through JSON tokens (`{` `}` `[` `]` `,` `:` `""` `123` - * `true` `false` `null`) - * detected by stage 1. - * - * @private This is not intended for external use. - */ -class token_iterator { - public: - /** - * Create a new invalid token_iterator. - * - * Exists so you can declare a variable and later assign to it before use. - */ - simdjson_really_inline token_iterator() noexcept = default; - simdjson_really_inline token_iterator(token_iterator &&other) noexcept = - default; - simdjson_really_inline token_iterator &operator=( - token_iterator &&other) noexcept = default; - simdjson_really_inline token_iterator( - const token_iterator &other) noexcept = default; - simdjson_really_inline token_iterator &operator=( - const token_iterator &other) noexcept = default; - - /** - * Advance to the next token (returning the current one). - */ - simdjson_really_inline const uint8_t *return_current_and_advance() noexcept; - /** - * Reports the current offset in bytes from the start of the underlying - * buffer. - */ - simdjson_really_inline uint32_t current_offset() const noexcept; - /** - * Get the JSON text for a given token (relative). - * - * This is not null-terminated; it is a view into the JSON. - * - * @param delta The relative position of the token to retrieve. e.g. 0 = - * current token, - * 1 = next token, -1 = prev token. - * - * TODO consider a string_view, assuming the length will get stripped out by - * the optimizer when - * it isn't used ... 
- */ - simdjson_really_inline const uint8_t *peek(int32_t delta = 0) const - noexcept; - /** - * Get the maximum length of the JSON text for a given token. - * - * The length will include any whitespace at the end of the token. - * - * @param delta The relative position of the token to retrieve. e.g. 0 = - * current token, - * 1 = next token, -1 = prev token. - */ - simdjson_really_inline uint32_t peek_length(int32_t delta = 0) const - noexcept; - - /** - * Get the JSON text for a given token. - * - * This is not null-terminated; it is a view into the JSON. - * - * @param position The position of the token. - * - */ - simdjson_really_inline const uint8_t *peek(token_position position) const - noexcept; - /** - * Get the maximum length of the JSON text for a given token. - * - * The length will include any whitespace at the end of the token. - * - * @param position The position of the token. - */ - simdjson_really_inline uint32_t peek_length(token_position position) const - noexcept; - - /** - * Return the current index. - */ - simdjson_really_inline token_position position() const noexcept; - /** - * Reset to a previously saved index. - */ - simdjson_really_inline void set_position( - token_position target_position) noexcept; - - // NOTE: we don't support a full C++ iterator interface, because we expect - // people to make - // different calls to advance the iterator based on *their own* state. 
- - simdjson_really_inline bool operator==(const token_iterator &other) const - noexcept; - simdjson_really_inline bool operator!=(const token_iterator &other) const - noexcept; - simdjson_really_inline bool operator>(const token_iterator &other) const - noexcept; - simdjson_really_inline bool operator>=(const token_iterator &other) const - noexcept; - simdjson_really_inline bool operator<(const token_iterator &other) const - noexcept; - simdjson_really_inline bool operator<=(const token_iterator &other) const - noexcept; - - protected: - simdjson_really_inline token_iterator(const uint8_t *buf, - token_position position) noexcept; - - /** - * Get the index of the JSON text for a given token (relative). - * - * This is not null-terminated; it is a view into the JSON. - * - * @param delta The relative position of the token to retrieve. e.g. 0 = - * current token, - * 1 = next token, -1 = prev token. - */ - simdjson_really_inline uint32_t peek_index(int32_t delta = 0) const - noexcept; - /** - * Get the index of the JSON text for a given token. - * - * This is not null-terminated; it is a view into the JSON. - * - * @param position The position of the token. 
- * - */ - simdjson_really_inline uint32_t peek_index(token_position position) const - noexcept; - - const uint8_t *buf{}; - token_position _position{}; - - friend class json_iterator; - friend class value_iterator; - friend class object; - friend simdjson_really_inline void logger::log_line( - const json_iterator &iter, - const char *title_prefix, - const char *title, - std::string_view detail, - int delta, - int depth_delta) noexcept; - friend simdjson_really_inline void logger::log_line( - const json_iterator &iter, - token_position index, - depth_t depth, - const char *title_prefix, - const char *title, - std::string_view detail) noexcept; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - simdjson_really_inline ~simdjson_result() noexcept = default; ///< @private -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/token_iterator.h */ -/* begin file include/simdjson/generic/ondemand/json_iterator.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class document; -class document_stream; -class object; -class array; -class value; -class raw_json_string; -class parser; - -/** - * Iterates through JSON tokens, keeping track of depth and string buffer. - * - * @private This is not intended for external use. 
- */ -class json_iterator { - protected: - token_iterator token{}; - ondemand::parser *parser{}; - /** - * Next free location in the string buffer. - * - * Used by raw_json_string::unescape() to have a place to unescape strings - * to. - */ - uint8_t *_string_buf_loc{}; - /** - * JSON error, if there is one. - * - * INCORRECT_TYPE and NO_SUCH_FIELD are *not* stored here, ever. - * - * PERF NOTE: we *hope* this will be elided into control flow, as it is only - * used (a) in the first - * iteration of the loop, or (b) for the final iteration after a missing - * comma is found in ++. If - * this is not elided, we should make sure it's at least not using up a - * register. Failing that, - * we should store it in document so there's only one of them. - */ - error_code error{SUCCESS}; - /** - * Depth of the current token in the JSON. - * - * - 0 = finished with document - * - 1 = document root value (could be [ or {, not yet known) - * - 2 = , or } inside root array/object - * - 3 = key or value inside root array/object. - */ - depth_t _depth{}; - /** - * Beginning of the document indexes. - * Normally we have root == parser->implementation->structural_indexes.get() - * but this may differ, especially in streaming mode (where we have several - * documents); - */ - token_position _root{}; - /** - * Normally, a json_iterator operates over a single document, but in - * some cases, we may have a stream of documents. This attribute is meant - * as meta-data: the json_iterator works the same irrespective of the - * value of this attribute. 
- */ - bool _streaming{false}; - - public: - simdjson_really_inline json_iterator() noexcept = default; - simdjson_really_inline json_iterator(json_iterator &&other) noexcept; - simdjson_really_inline json_iterator &operator=( - json_iterator &&other) noexcept; - simdjson_really_inline explicit json_iterator( - const json_iterator &other) noexcept = default; - simdjson_really_inline json_iterator &operator=( - const json_iterator &other) noexcept = default; - /** - * Skips a JSON value, whether it is a scalar, array or object. - */ - simdjson_warn_unused simdjson_really_inline error_code - skip_child(depth_t parent_depth) noexcept; - - /** - * Tell whether the iterator is still at the start - */ - simdjson_really_inline bool at_root() const noexcept; - - /** - * Tell whether we should be expected to run in streaming - * mode (iterating over many documents). It is pure metadata - * that does not affect how the iterator works. It is used by - * start_root_array() and start_root_object(). - */ - simdjson_really_inline bool streaming() const noexcept; - - /** - * Get the root value iterator - */ - simdjson_really_inline token_position root_position() const noexcept; - /** - * Assert that we are at the document depth (== 1) - */ - simdjson_really_inline void assert_at_document_depth() const noexcept; - /** - * Assert that we are at the root of the document - */ - simdjson_really_inline void assert_at_root() const noexcept; - - /** - * Tell whether the iterator is at the EOF mark - */ - simdjson_really_inline bool at_end() const noexcept; - - /** - * Tell whether the iterator is live (has not been moved). - */ - simdjson_really_inline bool is_alive() const noexcept; - - /** - * Abandon this iterator, setting depth to 0 (as if the document is - * finished). - */ - simdjson_really_inline void abandon() noexcept; - - /** - * Advance the current token without modifying depth. 
- */ - simdjson_really_inline const uint8_t *return_current_and_advance() noexcept; - - /** - * Assert that there are at least the given number of tokens left. - * - * Has no effect in release builds. - */ - simdjson_really_inline void assert_more_tokens( - uint32_t required_tokens = 1) const noexcept; - /** - * Assert that the given position addresses an actual token (is within - * bounds). - * - * Has no effect in release builds. - */ - simdjson_really_inline void assert_valid_position( - token_position position) const noexcept; - /** - * Get the JSON text for a given token (relative). - * - * This is not null-terminated; it is a view into the JSON. - * - * @param delta The relative position of the token to retrieve. e.g. 0 = - * next token, -1 = prev token. - * - * TODO consider a string_view, assuming the length will get stripped out by - * the optimizer when - * it isn't used ... - */ - simdjson_really_inline const uint8_t *peek(int32_t delta = 0) const - noexcept; - /** - * Get the maximum length of the JSON text for the current token (or - * relative). - * - * The length will include any whitespace at the end of the token. - * - * @param delta The relative position of the token to retrieve. e.g. 0 = - * next token, -1 = prev token. - */ - simdjson_really_inline uint32_t peek_length(int32_t delta = 0) const - noexcept; - /** - * Get a pointer to the current location in the input buffer. - * - * This is not null-terminated; it is a view into the JSON. - * - * You may be pointing outside of the input buffer: it is not generally - * safe to derefence this pointer. - */ - simdjson_really_inline const uint8_t *unsafe_pointer() const noexcept; - /** - * Get the JSON text for a given token. - * - * This is not null-terminated; it is a view into the JSON. - * - * @param position The position of the token to retrieve. - * - * TODO consider a string_view, assuming the length will get stripped out by - * the optimizer when - * it isn't used ... 
- */ - simdjson_really_inline const uint8_t *peek(token_position position) const - noexcept; - /** - * Get the maximum length of the JSON text for the current token (or - * relative). - * - * The length will include any whitespace at the end of the token. - * - * @param position The position of the token to retrieve. - */ - simdjson_really_inline uint32_t peek_length(token_position position) const - noexcept; - /** - * Get the JSON text for the last token in the document. - * - * This is not null-terminated; it is a view into the JSON. - * - * TODO consider a string_view, assuming the length will get stripped out by - * the optimizer when - * it isn't used ... - */ - simdjson_really_inline const uint8_t *peek_last() const noexcept; - - /** - * Ascend one level. - * - * Validates that the depth - 1 == parent_depth. - * - * @param parent_depth the expected parent depth. - */ - simdjson_really_inline void ascend_to(depth_t parent_depth) noexcept; - - /** - * Descend one level. - * - * Validates that the new depth == child_depth. - * - * @param child_depth the expected child depth. - */ - simdjson_really_inline void descend_to(depth_t child_depth) noexcept; - simdjson_really_inline void descend_to(depth_t child_depth, - int32_t delta) noexcept; - - /** - * Get current depth. - */ - simdjson_really_inline depth_t depth() const noexcept; - - /** - * Get current (writeable) location in the string buffer. - */ - simdjson_really_inline uint8_t *&string_buf_loc() noexcept; - - /** - * Report an unrecoverable error, preventing further iteration. - * - * @param error The error to report. Must not be SUCCESS, UNINITIALIZED, - * INCORRECT_TYPE, or NO_SUCH_FIELD. - * @param message An error message to report with the error. - */ - simdjson_really_inline error_code - report_error(error_code error, const char *message) noexcept; - - /** - * Log error, but don't stop iteration. - * @param error The error to report. Must be INCORRECT_TYPE, or - * NO_SUCH_FIELD. 
- * @param message An error message to report with the error. - */ - simdjson_really_inline error_code - optional_error(error_code error, const char *message) noexcept; - - template - simdjson_warn_unused simdjson_really_inline bool copy_to_buffer( - const uint8_t *json, uint32_t max_len, uint8_t (&tmpbuf)[N]) noexcept; - - simdjson_really_inline token_position position() const noexcept; - simdjson_really_inline void reenter_child(token_position position, - depth_t child_depth) noexcept; -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - simdjson_really_inline token_position start_position(depth_t depth) const - noexcept; - simdjson_really_inline void set_start_position( - depth_t depth, token_position position) noexcept; -#endif - /* Useful for debugging and logging purposes. */ - inline std::string to_string() const noexcept; - - /** - * Returns the current location in the document if in bounds. - */ - inline simdjson_result current_location() noexcept; - - /** - * Updates this json iterator so that it is back at the beginning of the - * document, - * as if it had just been created. - */ - inline void rewind() noexcept; - - protected: - simdjson_really_inline json_iterator(const uint8_t *buf, - ondemand::parser *parser) noexcept; - /// The last token before the end - simdjson_really_inline token_position last_position() const noexcept; - /// The token *at* the end. This points at gibberish and should only be used - /// for comparison. - simdjson_really_inline token_position end_position() const noexcept; - /// The end of the buffer. 
- simdjson_really_inline token_position end() const noexcept; - - friend class document; - friend class document_stream; - friend class object; - friend class array; - friend class value; - friend class raw_json_string; - friend class parser; - friend class value_iterator; - friend simdjson_really_inline void logger::log_line( - const json_iterator &iter, - const char *title_prefix, - const char *title, - std::string_view detail, - int delta, - int depth_delta) noexcept; - friend simdjson_really_inline void logger::log_line( - const json_iterator &iter, - token_position index, - depth_t depth, - const char *title_prefix, - const char *title, - std::string_view detail) noexcept; -}; // json_iterator - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - - simdjson_really_inline simdjson_result() noexcept = default; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/json_iterator.h */ -/* begin file include/simdjson/generic/ondemand/value_iterator.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class document; -class object; -class array; -class value; -class raw_json_string; -class parser; - -/** - * Iterates through a single JSON value at a particular depth. - * - * Does not keep track of the type of value: provides methods for objects, - * arrays and scalars and expects - * the caller to call the right ones. - * - * @private This is not intended for external use. 
- */ -class value_iterator { - protected: - /** The underlying JSON iterator */ - json_iterator *_json_iter{}; - /** The depth of this value */ - depth_t _depth{}; - /** - * The starting token index for this value - */ - token_position _start_position{}; - - public: - simdjson_really_inline value_iterator() noexcept = default; - - /** - * Denote that we're starting a document. - */ - simdjson_really_inline void start_document() noexcept; - - /** - * Skips a non-iterated or partially-iterated JSON value, whether it is a - * scalar, array or object. - * - * Optimized for scalars. - */ - simdjson_warn_unused simdjson_really_inline error_code - skip_child() noexcept; - - /** - * Tell whether the iterator is at the EOF mark - */ - simdjson_really_inline bool at_end() const noexcept; - - /** - * Tell whether the iterator is at the start of the value - */ - simdjson_really_inline bool at_start() const noexcept; - - /** - * Tell whether the value is open--if the value has not been used, or the - * array/object is still open. - */ - simdjson_really_inline bool is_open() const noexcept; - - /** - * Tell whether the value is at an object's first field (just after the {). - */ - simdjson_really_inline bool at_first_field() const noexcept; - - /** - * Abandon all iteration. - */ - simdjson_really_inline void abandon() noexcept; - - /** - * Get the child value as a value_iterator. - */ - simdjson_really_inline value_iterator child_value() const noexcept; - - /** - * Get the depth of this value. - */ - simdjson_really_inline depth_t depth() const noexcept; - - /** - * Get the JSON type of this value. - * - * @error TAPE_ERROR when the JSON value is a bad token like "}" "," or - * "alse". - */ - simdjson_really_inline simdjson_result type() const noexcept; - - /** - * @addtogroup object Object iteration - * - * Methods to iterate and find object fields. 
These methods generally - * *assume* the value is - * actually an object; the caller is responsible for keeping track of that - * fact. - * - * @{ - */ - - /** - * Start an object iteration. - * - * @returns Whether the object had any fields (returns false for empty). - * @error INCORRECT_TYPE if there is no opening { - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - start_object() noexcept; - /** - * Start an object iteration from the root. - * - * @returns Whether the object had any fields (returns false for empty). - * @error INCORRECT_TYPE if there is no opening { - * @error TAPE_ERROR if there is no matching } at end of document - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - start_root_object() noexcept; - - /** - * Start an object iteration after the user has already checked and moved - * past the {. - * - * Does not move the iterator unless the object is empty ({}). - * - * @returns Whether the object had any fields (returns false for empty). - * @error INCOMPLETE_ARRAY_OR_OBJECT If there are no more tokens (implying - * the *parent* - * array or object is incomplete). - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - started_object() noexcept; - /** - * Start an object iteration from the root, after the user has already - * checked and moved past the {. - * - * Does not move the iterator unless the object is empty ({}). - * - * @returns Whether the object had any fields (returns false for empty). - * @error INCOMPLETE_ARRAY_OR_OBJECT If there are no more tokens (implying - * the *parent* - * array or object is incomplete). - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - started_root_object() noexcept; - - /** - * Moves to the next field in an object. - * - * Looks for , and }. If } is found, the object is finished and the iterator - * advances past it. - * Otherwise, it advances to the next value. - * - * @return whether there is another field in the object. 
- * @error TAPE_ERROR If there is a comma missing between fields. - * @error TAPE_ERROR If there is a comma, but not enough tokens remaining to - * have a key, :, and value. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - has_next_field() noexcept; - - /** - * Get the current field's key. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - field_key() noexcept; - - /** - * Pass the : in the field and move to its value. - */ - simdjson_warn_unused simdjson_really_inline error_code - field_value() noexcept; - - /** - * Find the next field with the given key. - * - * Assumes you have called next_field() or otherwise matched the previous - * value. - * - * This means the iterator must be sitting at the next key: - * - * ``` - * { "a": 1, "b": 2 } - * ^ - * ``` - * - * Key is *raw JSON,* meaning it will be matched against the verbatim JSON - * without attempting to - * unescape it. This works well for typical ASCII and UTF-8 keys (almost all - * of them), but may - * fail to match some keys with escapes (\u, \n, etc.). - */ - simdjson_warn_unused simdjson_really_inline error_code - find_field(const std::string_view key) noexcept; - - /** - * Find the next field with the given key, *without* unescaping. This - * assumes object order: it - * will not find the field if it was already passed when looking for some - * *other* field. - * - * Assumes you have called next_field() or otherwise matched the previous - * value. - * - * This means the iterator must be sitting at the next key: - * - * ``` - * { "a": 1, "b": 2 } - * ^ - * ``` - * - * Key is *raw JSON,* meaning it will be matched against the verbatim JSON - * without attempting to - * unescape it. This works well for typical ASCII and UTF-8 keys (almost all - * of them), but may - * fail to match some keys with escapes (\u, \n, etc.). 
- */ - simdjson_warn_unused simdjson_really_inline simdjson_result - find_field_raw(const std::string_view key) noexcept; - - /** - * Find the field with the given key without regard to order, and *without* - * unescaping. - * - * This is an unordered object lookup: if the field is not found initially, - * it will cycle around and scan from the beginning. - * - * Assumes you have called next_field() or otherwise matched the previous - * value. - * - * This means the iterator must be sitting at the next key: - * - * ``` - * { "a": 1, "b": 2 } - * ^ - * ``` - * - * Key is *raw JSON,* meaning it will be matched against the verbatim JSON - * without attempting to - * unescape it. This works well for typical ASCII and UTF-8 keys (almost all - * of them), but may - * fail to match some keys with escapes (\u, \n, etc.). - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - find_field_unordered_raw(const std::string_view key) noexcept; - - /** @} */ - - /** - * @addtogroup array Array iteration - * Methods to iterate over array elements. These methods generally *assume* - * the value is actually - * an object; the caller is responsible for keeping track of that fact. - * @{ - */ - - /** - * Check for an opening [ and start an array iteration. - * - * @returns Whether the array had any elements (returns false for empty). - * @error INCORRECT_TYPE If there is no [. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - start_array() noexcept; - /** - * Check for an opening [ and start an array iteration while at the root. - * - * @returns Whether the array had any elements (returns false for empty). - * @error INCORRECT_TYPE If there is no [. - * @error TAPE_ERROR if there is no matching ] at end of document - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - start_root_array() noexcept; - - /** - * Start an array iteration, after the user has already checked and moved - * past the [. 
- * - * Does not move the iterator unless the array is empty ([]). - * - * @returns Whether the array had any elements (returns false for empty). - * @error INCOMPLETE_ARRAY_OR_OBJECT If there are no more tokens (implying - * the *parent* - * array or object is incomplete). - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - started_array() noexcept; - /** - * Start an array iteration from the root, after the user has already - * checked and moved past the [. - * - * Does not move the iterator unless the array is empty ([]). - * - * @returns Whether the array had any elements (returns false for empty). - * @error INCOMPLETE_ARRAY_OR_OBJECT If there are no more tokens (implying - * the *parent* - * array or object is incomplete). - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - started_root_array() noexcept; - - /** - * Moves to the next element in an array. - * - * Looks for , and ]. If ] is found, the array is finished and the iterator - * advances past it. - * Otherwise, it advances to the next value. - * - * @return Whether there is another element in the array. - * @error TAPE_ERROR If there is a comma missing between elements. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - has_next_element() noexcept; - - /** - * Get a child value iterator. 
- */ - simdjson_warn_unused simdjson_really_inline value_iterator child() const - noexcept; - - /** @} */ - - /** - * @defgroup scalar Scalar values - * @addtogroup scalar - * @{ - */ - - simdjson_warn_unused simdjson_really_inline - simdjson_result - get_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_raw_json_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_uint64() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_uint64_in_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_int64() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_int64_in_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_double() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_double_in_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_bool() noexcept; - simdjson_really_inline bool is_null() noexcept; - simdjson_warn_unused simdjson_really_inline bool is_negative() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - is_integer() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_number_type() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_number() noexcept; - - simdjson_warn_unused simdjson_really_inline - simdjson_result - get_root_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_raw_json_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_uint64() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_uint64_in_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_int64() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_int64_in_string() noexcept; - 
simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_double() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_double_in_string() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_bool() noexcept; - simdjson_warn_unused simdjson_really_inline bool - is_root_negative() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - is_root_integer() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_number_type() noexcept; - simdjson_warn_unused simdjson_really_inline simdjson_result - get_root_number() noexcept; - simdjson_really_inline bool is_root_null() noexcept; - - simdjson_really_inline error_code error() const noexcept; - simdjson_really_inline uint8_t *&string_buf_loc() noexcept; - simdjson_really_inline const json_iterator &json_iter() const noexcept; - simdjson_really_inline json_iterator &json_iter() noexcept; - - simdjson_really_inline void assert_is_valid() const noexcept; - simdjson_really_inline bool is_valid() const noexcept; - - /** @} */ - protected: - /** - * Restarts an array iteration. - * @returns Whether the array has any elements (returns false for empty). - */ - simdjson_really_inline simdjson_result reset_array() noexcept; - /** - * Restarts an object iteration. - * @returns Whether the object has any fields (returns false for empty). - */ - simdjson_really_inline simdjson_result reset_object() noexcept; - /** - * move_at_start(): moves us so that we are pointing at the beginning of - * the container. It updates the index so that at_start() is true and it - * syncs the depth. The user can then create a new container instance. - * - * Usage: used with value::count_elements(). - **/ - simdjson_really_inline void move_at_start() noexcept; - - /** - * move_at_container_start(): moves us so that we are pointing at the - *beginning of - * the container so that assert_at_container_start() passes. 
- * - * Usage: used with reset_array() and reset_object(). - **/ - simdjson_really_inline void move_at_container_start() noexcept; - /* Useful for debugging and logging purposes. */ - inline std::string to_string() const noexcept; - simdjson_really_inline value_iterator(json_iterator *json_iter, - depth_t depth, - token_position start_index) noexcept; - - simdjson_really_inline bool parse_null(const uint8_t *json) const noexcept; - simdjson_really_inline simdjson_result parse_bool( - const uint8_t *json) const noexcept; - simdjson_really_inline const uint8_t *peek_start() const noexcept; - simdjson_really_inline uint32_t peek_start_length() const noexcept; - - /** - * The general idea of the advance_... methods and the peek_* methods - * is that you first peek and check that you have desired type. If you do, - * and only if you do, then you advance. - * - * We used to unconditionally advance. But this made reasoning about our - * current state difficult. - * Suppose you always advance. Look at the 'value' matching the key - * "shadowable" in the following example... - * - * ({"globals":{"a":{"shadowable":[}}}}) - * - * If the user thinks it is a Boolean and asks for it, then we check the - * '[', - * decide it is not a Boolean, but still move into the next character ('}'). - * Now - * we are left pointing at '}' right after a '['. And we have not yet - * reported - * an error, only that we do not have a Boolean. - * - * If, instead, you just stand your ground until it is content that you - * know, then - * you will only even move beyond the '[' if the user tells you that you - * have an - * array. So you will be at the '}' character inside the array and, - * hopefully, you - * will then catch the error because an array cannot start with '}', but the - * code - * processing Boolean values does not know this. - * - * So the contract is: first call 'peek_...' and then call 'advance_...' - * only - * if you have determined that it is a type you can handle. 
- * - * Unfortunately, it makes the code more verbose, longer and maybe more - * error prone. - */ - - simdjson_really_inline void advance_scalar(const char *type) noexcept; - simdjson_really_inline void advance_root_scalar(const char *type) noexcept; - simdjson_really_inline void advance_non_root_scalar( - const char *type) noexcept; - - simdjson_really_inline const uint8_t *peek_scalar( - const char *type) noexcept; - simdjson_really_inline const uint8_t *peek_root_scalar( - const char *type) noexcept; - simdjson_really_inline const uint8_t *peek_non_root_scalar( - const char *type) noexcept; - - - simdjson_really_inline error_code - start_container(uint8_t start_char, - const char *incorrect_type_message, - const char *type) noexcept; - simdjson_really_inline error_code end_container() noexcept; - - /** - * Advance to a place expecting a value (increasing depth). - * - * @return The current token (the one left behind). - * @error TAPE_ERROR If the document ended early. - */ - simdjson_really_inline simdjson_result - advance_to_value() noexcept; - - simdjson_really_inline error_code - incorrect_type_error(const char *message) const noexcept; - simdjson_really_inline error_code - error_unless_more_tokens(uint32_t tokens = 1) const noexcept; - - simdjson_really_inline bool is_at_start() const noexcept; - /** - * is_at_iterator_start() returns true on an array or object after it has - * just been - * created, whether the instance is empty or not. - * - * Usage: used by array::begin() in debug mode (SIMDJSON_DEVELOPMENT_CHECKS) - */ - simdjson_really_inline bool is_at_iterator_start() const noexcept; - - /** - * Assuming that we are within an object, this returns true if we - * are pointing at a key. - * - * Usage: the skip_child() method should never be used while we are pointing - * at a key inside an object. 
- */ - simdjson_really_inline bool is_at_key() const noexcept; - - inline void assert_at_start() const noexcept; - inline void assert_at_container_start() const noexcept; - inline void assert_at_root() const noexcept; - inline void assert_at_child() const noexcept; - inline void assert_at_next() const noexcept; - inline void assert_at_non_root_start() const noexcept; - - /** Get the starting position of this value */ - simdjson_really_inline token_position start_position() const noexcept; - - /** @copydoc error_code json_iterator::position() const noexcept; */ - simdjson_really_inline token_position position() const noexcept; - /** @copydoc error_code json_iterator::end_position() const noexcept; */ - simdjson_really_inline token_position last_position() const noexcept; - /** @copydoc error_code json_iterator::end_position() const noexcept; */ - simdjson_really_inline token_position end_position() const noexcept; - /** @copydoc error_code json_iterator::report_error(error_code error, const - * char *message) noexcept; */ - simdjson_really_inline error_code - report_error(error_code error, const char *message) noexcept; - - friend class document; - friend class object; - friend class array; - friend class value; -}; // value_iterator - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; -}; - -} // namespace simdjson -/* end file 
include/simdjson/generic/ondemand/value_iterator.h */ -/* begin file include/simdjson/generic/ondemand/array_iterator.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class array; -class value; -class document; - -/** - * A forward-only JSON array. - * - * This is an input_iterator, meaning: - * - It is forward-only - * - * must be called exactly once per element. - * - ++ must be called exactly once in between each * (*, ++, *, ++, * ...) - */ -class array_iterator { - public: - /** Create a new, invalid array iterator. */ - simdjson_really_inline array_iterator() noexcept = default; - - // - // Iterator interface - // - - /** - * Get the current element. - * - * Part of the std::iterator interface. - */ - simdjson_really_inline simdjson_result operator - *() noexcept; // MUST ONLY BE CALLED ONCE PER ITERATION. - /** - * Check if we are at the end of the JSON. - * - * Part of the std::iterator interface. - * - * @return true if there are no more elements in the JSON array. - */ - simdjson_really_inline bool operator==(const array_iterator &) const - noexcept; - /** - * Check if there are more elements in the JSON array. - * - * Part of the std::iterator interface. - * - * @return true if there are more elements in the JSON array. - */ - simdjson_really_inline bool operator!=(const array_iterator &) const - noexcept; - /** - * Move to the next element. - * - * Part of the std::iterator interface. 
- */ - simdjson_really_inline array_iterator &operator++() noexcept; - - private: - value_iterator iter{}; - - simdjson_really_inline array_iterator(const value_iterator &iter) noexcept; - - friend class array; - friend class value; - friend struct simdjson_result; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - // - // Iterator interface - // - - simdjson_really_inline - simdjson_result - operator*() noexcept; // MUST ONLY BE CALLED ONCE PER ITERATION. - simdjson_really_inline bool operator==( - const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> &) const - noexcept; - simdjson_really_inline bool operator!=( - const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> &) const - noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - &operator++() noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/array_iterator.h */ -/* begin file include/simdjson/generic/ondemand/object_iterator.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class field; - -class object_iterator { - public: - /** - * Create a new invalid object_iterator. - * - * Exists so you can declare a variable and later assign to it before use. 
- */ - simdjson_really_inline object_iterator() noexcept = default; - - // - // Iterator interface - // - - // Reads key and value, yielding them to the user. - // MUST ONLY BE CALLED ONCE PER ITERATION. - simdjson_really_inline simdjson_result operator*() noexcept; - // Assumes it's being compared with the end. true if depth < iter->depth. - simdjson_really_inline bool operator==(const object_iterator &) const - noexcept; - // Assumes it's being compared with the end. true if depth >= iter->depth. - simdjson_really_inline bool operator!=(const object_iterator &) const - noexcept; - // Checks for ']' and ',' - simdjson_really_inline object_iterator &operator++() noexcept; - - private: - /** - * The underlying JSON iterator. - * - * PERF NOTE: expected to be elided in favor of the parent document: this is - * set when the object - * is first used, and never changes afterwards. - */ - value_iterator iter{}; - - simdjson_really_inline object_iterator(const value_iterator &iter) noexcept; - friend struct simdjson_result; - friend class object; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - // - // Iterator interface - // - - // Reads key and value, yielding them to the user. - simdjson_really_inline - simdjson_result - operator*() noexcept; // MUST ONLY BE CALLED ONCE PER ITERATION. - // Assumes it's being compared with the end. 
true if depth < iter->depth. - simdjson_really_inline bool operator==( - const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> &) const - noexcept; - // Assumes it's being compared with the end. true if depth >= iter->depth. - simdjson_really_inline bool operator!=( - const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> &) const - noexcept; - // Checks for ']' and ',' - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> - &operator++() noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/object_iterator.h */ -/* begin file include/simdjson/generic/ondemand/array.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class value; -class document; - -/** - * A forward-only JSON array. - */ -class array { - public: - /** - * Create a new invalid array. - * - * Exists so you can declare a variable and later assign to it before use. - */ - simdjson_really_inline array() noexcept = default; - - /** - * Begin array iteration. - * - * Part of the std::iterable interface. - */ - simdjson_really_inline simdjson_result begin() noexcept; - /** - * Sentinel representing the end of the array. - * - * Part of the std::iterable interface. - */ - simdjson_really_inline simdjson_result end() noexcept; - /** - * This method scans the array and counts the number of elements. - * The count_elements method should always be called before you have begun - * iterating through the array: it is expected that you are pointing at - * the beginning of the array. - * The runtime complexity is linear in the size of the array. After - * calling this function, if successful, the array is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. 
- * - * To check that an array is empty, it is more performant to use - * the is_empty() method. - */ - simdjson_really_inline simdjson_result count_elements() & noexcept; - /** - * This method scans the beginning of the array and checks whether the - * array is empty. - * The runtime complexity is constant time. After - * calling this function, if successful, the array is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - */ - simdjson_really_inline simdjson_result is_empty() & noexcept; - /** - * Reset the iterator so that we are pointing back at the - * beginning of the array. You should still consume values only once even if - * you - * can iterate through the array more than once. If you unescape a string - * within the array more than once, you have unsafe code. Note that - * rewinding - * an array means that you may need to reparse it anew: it is not a free - * operation. - * - * @returns true if the array contains some elements (not empty) - */ - inline simdjson_result reset() & noexcept; - /** - * Get the value associated with the given JSON pointer. We use the RFC - * 6901 - * https://tools.ietf.org/html/rfc6901 standard, interpreting the current - * node - * as the root of its own JSON document. - * - * ondemand::parser parser; - * auto json = R"([ { "foo": { "a": [ 10, 20, 30 ] }} ])"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("/0/foo/a/1") == 20 - * - * Note that at_pointer() called on the document automatically calls the - * document's rewind - * method between each call. It invalidates all previously accessed arrays, - * objects and values - * that have not been consumed. Yet it is not the case when calling - * at_pointer on an array - * instance: there is no rewind and no invalidation. 
- * - * You may only call at_pointer on an array after it has been created, but - * before it has - * been first accessed. When calling at_pointer on an array, the pointer is - * advanced to - * the location indicated by the JSON pointer (in case of success). It is no - * longer possible - * to call at_pointer on the same array. - * - * Also note that at_pointer() relies on find_field() which implies that we - * do not unescape keys when matching. - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - inline simdjson_result at_pointer( - std::string_view json_pointer) noexcept; - /** - * Consumes the array and returns a string_view instance corresponding to - * the - * array as represented in JSON. It points inside the original document. - */ - simdjson_really_inline simdjson_result - raw_json() noexcept; - - /** - * Get the value at the given index. This function has linear-time - * complexity. - * This function should only be called once on an array instance since the - * array iterator is not reset between each call. - * - * @return The value at the given index, or: - * - INDEX_OUT_OF_BOUNDS if the array index is larger than an array - * length - */ - simdjson_really_inline simdjson_result at(size_t index) noexcept; - - protected: - /** - * Go to the end of the array, no matter where you are right now. - */ - simdjson_really_inline error_code consume() noexcept; - - /** - * Begin array iteration. - * - * @param iter The iterator. Must be where the initial [ is expected. Will - * be *moved* into the - * resulting array. - * @error INCORRECT_TYPE if the iterator is not at [. 
- */ - static simdjson_really_inline simdjson_result start( - value_iterator &iter) noexcept; - /** - * Begin array iteration from the root. - * - * @param iter The iterator. Must be where the initial [ is expected. Will - * be *moved* into the - * resulting array. - * @error INCORRECT_TYPE if the iterator is not at [. - * @error TAPE_ERROR if there is no closing ] at the end of the document. - */ - static simdjson_really_inline simdjson_result start_root( - value_iterator &iter) noexcept; - /** - * Begin array iteration. - * - * This version of the method should be called after the initial [ has been - * verified, and is - * intended for use by switch statements that check the type of a value. - * - * @param iter The iterator. Must be after the initial [. Will be *moved* - * into the resulting array. - */ - static simdjson_really_inline simdjson_result started( - value_iterator &iter) noexcept; - - /** - * Create an array at the given Internal array creation. Call array::start() - * or array::started() instead of this. - * - * @param iter The iterator. Must either be at the start of the first - * element with iter.is_alive() - * == true, or past the [] with is_alive() == false if the array is - * empty. Will be *moved* - * into the resulting array. - */ - simdjson_really_inline array(const value_iterator &iter) noexcept; - - /** - * Iterator marking current position. - * - * iter.is_alive() == false indicates iteration is complete. 
- */ - value_iterator iter{}; - - friend class value; - friend class document; - friend struct simdjson_result; - friend struct simdjson_result; - friend class array_iterator; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - begin() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - end() noexcept; - inline simdjson_result count_elements() & noexcept; - inline simdjson_result is_empty() & noexcept; - inline simdjson_result reset() & noexcept; - simdjson_really_inline - simdjson_result - at(size_t index) noexcept; - simdjson_really_inline - simdjson_result - at_pointer(std::string_view json_pointer) noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/array.h */ -/* begin file include/simdjson/generic/ondemand/document.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class parser; -class array; -class object; -class value; -class raw_json_string; -class array_iterator; -class document_stream; - -/** - * A JSON document. It holds a json_iterator instance. - * - * Used by tokens to get text, and string buffer location. - * - * You must keep the document around during iteration. - */ -class document { - public: - /** - * Create a new invalid document. 
- * - * Exists so you can declare a variable and later assign to it before use. - */ - simdjson_really_inline document() noexcept = default; - simdjson_really_inline document(const document &other) noexcept = - delete; // pass your documents by reference, not by copy - simdjson_really_inline document(document &&other) noexcept = default; - simdjson_really_inline document &operator=(const document &other) noexcept = - delete; - simdjson_really_inline document &operator=(document &&other) noexcept = - default; - - /** - * Cast this JSON value to an array. - * - * @returns An object that can be used to iterate the array. - * @returns INCORRECT_TYPE If the JSON value is not an array. - */ - simdjson_really_inline simdjson_result get_array() & noexcept; - /** - * Cast this JSON value to an object. - * - * @returns An object that can be used to look up or iterate fields. - * @returns INCORRECT_TYPE If the JSON value is not an object. - */ - simdjson_really_inline simdjson_result get_object() & noexcept; - /** - * Cast this JSON value to an unsigned integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit unsigned - * integer. - */ - simdjson_really_inline simdjson_result get_uint64() noexcept; - /** - * Cast this JSON value (inside string) to an unsigned integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit unsigned - * integer. - */ - simdjson_really_inline simdjson_result - get_uint64_in_string() noexcept; - /** - * Cast this JSON value to a signed integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit integer. - */ - simdjson_really_inline simdjson_result get_int64() noexcept; - /** - * Cast this JSON value (inside string) to a signed integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit integer. 
- */ - simdjson_really_inline simdjson_result - get_int64_in_string() noexcept; - /** - * Cast this JSON value to a double. - * - * @returns A double. - * @returns INCORRECT_TYPE If the JSON value is not a valid floating-point - * number. - */ - simdjson_really_inline simdjson_result get_double() noexcept; - - /** - * Cast this JSON value (inside string) to a double. - * - * @returns A double. - * @returns INCORRECT_TYPE If the JSON value is not a valid floating-point - * number. - */ - simdjson_really_inline simdjson_result - get_double_in_string() noexcept; - /** - * Cast this JSON value to a string. - * - * The string is guaranteed to be valid UTF-8. - * - * Important: Calling get_string() twice on the same document is an error. - * - * @returns An UTF-8 string. The string is stored in the parser and will be - * invalidated the next - * time it parses a document or when it is destroyed. - * @returns INCORRECT_TYPE if the JSON value is not a string. - */ - simdjson_really_inline simdjson_result - get_string() noexcept; - /** - * Cast this JSON value to a raw_json_string. - * - * The string is guaranteed to be valid UTF-8, and may have escapes in it - * (e.g. \\ or \n). - * - * @returns A pointer to the raw JSON for the given string. - * @returns INCORRECT_TYPE if the JSON value is not a string. - */ - simdjson_really_inline simdjson_result - get_raw_json_string() noexcept; - /** - * Cast this JSON value to a bool. - * - * @returns A bool value. - * @returns INCORRECT_TYPE if the JSON value is not true or false. - */ - simdjson_really_inline simdjson_result get_bool() noexcept; - /** - * Cast this JSON value to a value when the document is an object or an - * array. - * - * @returns A value if a JSON array or object cannot be found. - * @returns SCALAR_DOCUMENT_AS_VALUE error is the document is a scalar (see - * is_scalar() function). - */ - simdjson_really_inline simdjson_result get_value() noexcept; - - /** - * Checks if this JSON value is null. 
- * - * @returns Whether the value is null. - */ - simdjson_really_inline bool is_null() noexcept; - - /** - * Get this value as the given type. - * - * Supported types: object, array, raw_json_string, string_view, uint64_t, - * int64_t, double, bool - * - * You may use get_double(), get_bool(), get_uint64(), get_int64(), - * get_object(), get_array(), get_raw_json_string(), or get_string() - * instead. - * - * @returns A value of the given type, parsed from the JSON. - * @returns INCORRECT_TYPE If the JSON value is not the given type. - */ - template - simdjson_really_inline simdjson_result get() & noexcept { - // Unless the simdjson library provides an inline implementation, - // calling this method should - // immediately fail. - static_assert(!sizeof(T), - "The get method with given type is not implemented by " - "the simdjson library."); - } - /** @overload template simdjson_result get() & noexcept */ - template - simdjson_really_inline simdjson_result get() && noexcept { - // Unless the simdjson library provides an inline implementation, - // calling this method should - // immediately fail. - static_assert(!sizeof(T), - "The get method with given type is not implemented by " - "the simdjson library."); - } - - /** - * Get this value as the given type. - * - * Supported types: object, array, raw_json_string, string_view, uint64_t, - * int64_t, double, bool, value - * - * Be mindful that the document instance must remain in scope while you are - * accessing object, array and value instances. - * - * @param out This is set to a value of the given type, parsed from the - * JSON. If there is an error, this may not be initialized. - * @returns INCORRECT_TYPE If the JSON value is not an object. - * @returns SUCCESS If the parse succeeded and the out parameter was set to - * the value. 
- */ - template - simdjson_really_inline error_code get(T &out) & noexcept; - /** @overload template error_code get(T &out) & noexcept */ - template - simdjson_really_inline error_code get(T &out) && noexcept; - -#if SIMDJSON_EXCEPTIONS - /** - * Cast this JSON value to an array. - * - * @returns An object that can be used to iterate the array. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not an - * array. - */ - simdjson_really_inline operator array() & noexcept(false); - /** - * Cast this JSON value to an object. - * - * @returns An object that can be used to look up or iterate fields. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not an - * object. - */ - simdjson_really_inline operator object() & noexcept(false); - /** - * Cast this JSON value to an unsigned integer. - * - * @returns A signed 64-bit integer. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * 64-bit unsigned integer. - */ - simdjson_really_inline operator uint64_t() noexcept(false); - /** - * Cast this JSON value to a signed integer. - * - * @returns A signed 64-bit integer. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * 64-bit integer. - */ - simdjson_really_inline operator int64_t() noexcept(false); - /** - * Cast this JSON value to a double. - * - * @returns A double. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * valid floating-point number. - */ - simdjson_really_inline operator double() noexcept(false); - /** - * Cast this JSON value to a string. - * - * The string is guaranteed to be valid UTF-8. - * - * @returns An UTF-8 string. The string is stored in the parser and will be - * invalidated the next - * time it parses a document or when it is destroyed. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not a - * string. 
- */ - simdjson_really_inline operator std::string_view() noexcept(false); - /** - * Cast this JSON value to a raw_json_string. - * - * The string is guaranteed to be valid UTF-8, and may have escapes in it - * (e.g. \\ or \n). - * - * @returns A pointer to the raw JSON for the given string. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not a - * string. - */ - simdjson_really_inline operator raw_json_string() noexcept(false); - /** - * Cast this JSON value to a bool. - * - * @returns A bool value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not true - * or false. - */ - simdjson_really_inline operator bool() noexcept(false); - /** - * Cast this JSON value to a value. - * - * @returns A value value. - * @exception if a JSON value cannot be found - */ - simdjson_really_inline operator value() noexcept(false); -#endif - /** - * This method scans the array and counts the number of elements. - * The count_elements method should always be called before you have begun - * iterating through the array: it is expected that you are pointing at - * the beginning of the array. - * The runtime complexity is linear in the size of the array. After - * calling this function, if successful, the array is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - */ - simdjson_really_inline simdjson_result count_elements() & noexcept; - /** - * This method scans the object and counts the number of key-value pairs. - * The count_fields method should always be called before you have begun - * iterating through the object: it is expected that you are pointing at - * the beginning of the object. - * The runtime complexity is linear in the size of the object. After - * calling this function, if successful, the object is 'rewinded' at its - * beginning as if it had never been accessed. 
If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - * - * To check that an object is empty, it is more performant to use - * the is_empty() method. - */ - simdjson_really_inline simdjson_result count_fields() & noexcept; - /** - * Get the value at the given index in the array. This function has - * linear-time complexity. - * This function should only be called once on an array instance since the - * array iterator is not reset between each call. - * - * @return The value at the given index, or: - * - INDEX_OUT_OF_BOUNDS if the array index is larger than an array - * length - */ - simdjson_really_inline simdjson_result at(size_t index) & noexcept; - /** - * Begin array iteration. - * - * Part of the std::iterable interface. - */ - simdjson_really_inline simdjson_result begin() & noexcept; - /** - * Sentinel representing the end of the array. - * - * Part of the std::iterable interface. - */ - simdjson_really_inline simdjson_result end() & noexcept; - - /** - * Look up a field by name on an object (order-sensitive). - * - * The following code reads z, then y, then x, and thus will not retrieve x - * or y if fed the - * JSON `{ "x": 1, "y": 2, "z": 3 }`: - * - * ```c++ - * simdjson::ondemand::parser parser; - * auto obj = parser.parse(R"( { "x": 1, "y": 2, "z": 3 } )"_padded); - * double z = obj.find_field("z"); - * double y = obj.find_field("y"); - * double x = obj.find_field("x"); - * ``` - * - * **Raw Keys:** The lookup will be done against the *raw* key, and will not - * unescape keys. - * e.g. `object["a"]` will match `{ "a": 1 }`, but will *not* match `{ - * "\u0061": 1 }`. - * - * - * You must consume the fields on an object one at a time. A request for a - * new key - * invalidates previous field values: it makes them unsafe. E.g., the array - * given by content["bids"].get_array() should not be accessed after you - * have called - * content["asks"].get_array(). 
You can detect such mistakes by first - * compiling and running - * the code in Debug mode (or with the macro `SIMDJSON_DEVELOPMENT_CHECKS` - * set to 1): an - * OUT_OF_ORDER_ITERATION error is generated. - * - * You are expected to access keys only once. You should access the value - * corresponding to - * a key a single time. Doing object["mykey"].to_string()and then again - * object["mykey"].to_string() - * is an error. - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. - */ - simdjson_really_inline simdjson_result find_field( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result find_field(const char *key) & - noexcept; - - /** - * Look up a field by name on an object, without regard to key order. - * - * **Performance Notes:** This is a bit less performant than find_field(), - * though its effect varies - * and often appears negligible. It starts out normally, starting out at the - * last field; but if - * the field is not found, it scans from the beginning of the object to see - * if it missed it. That - * missing case has a non-cache-friendly bump and lots of extra scanning, - * especially if the object - * in question is large. The fact that the extra code is there also bumps - * the executable size. - * - * It is the default, however, because it would be highly surprising (and - * hard to debug) if the - * default behavior failed to look up a field just because it was in the - * wrong order--and many - * APIs assume this. Therefore, you must be explicit if you want to treat - * objects as out of order. - * - * Use find_field() if you are sure fields will be in order (or are willing - * to treat it as if the - * field wasn't there when they aren't). - * - * You must consume the fields on an object one at a time. 
A request for a - * new key - * invalidates previous field values: it makes them unsafe. E.g., the array - * given by content["bids"].get_array() should not be accessed after you - * have called - * content["asks"].get_array(). You can detect such mistakes by first - * compiling and running - * the code in Debug mode (or with the macro `SIMDJSON_DEVELOPMENT_CHECKS` - * set to 1): an - * OUT_OF_ORDER_ITERATION error is generated. - * - * You are expected to access keys only once. You should access the value - * corresponding to a key - * a single time. Doing object["mykey"].to_string() and then again - * object["mykey"].to_string() - * is an error. - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. - */ - simdjson_really_inline simdjson_result find_field_unordered( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result find_field_unordered( - const char *key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result operator[]( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result operator[](const char *key) & - noexcept; - - /** - * Get the type of this JSON value. - * - * NOTE: If you're only expecting a value to be one type (a typical case), - * it's generally - * better to just call .get_double, .get_string, etc. and check for - * INCORRECT_TYPE (or just - * let it throw an exception). - * - * @error TAPE_ERROR when the JSON value is a bad token like "}" "," or - * "alse". 
- */ - simdjson_really_inline simdjson_result type() noexcept; - - /** - * Checks whether the document is a scalar (string, number, null, Boolean). - * Returns false when there it is an array or object. - * - * @returns true if the type is string, number, null, Boolean - * @error TAPE_ERROR when the JSON value is a bad token like "}" "," or - * "alse". - */ - simdjson_really_inline simdjson_result is_scalar() noexcept; - - /** - * Checks whether the document is a negative number. - * - * @returns true if the number if negative. - */ - simdjson_really_inline bool is_negative() noexcept; - /** - * Checks whether the document is an integer number. Note that - * this requires to partially parse the number string. If - * the value is determined to be an integer, it may still - * not parse properly as an integer in subsequent steps - * (e.g., it might overflow). - * - * @returns true if the number if negative. - */ - simdjson_really_inline simdjson_result is_integer() noexcept; - /** - * Determine the number type (integer or floating-point number). - * - * get_number_type() is number_type::unsigned_integer if we have - * an integer greater or equal to 9223372036854775808 - * get_number_type() is number_type::signed_integer if we have an - * integer that is less than 9223372036854775808 - * Otherwise, get_number_type() has value number_type::floating_point_number - * - * This function req - * uires processing the number string, but it is expected - * to be faster than get_number().get_number_type() because it is does not - * parse the number value. - * - * @returns the type of the number - */ - simdjson_really_inline simdjson_result - get_number_type() noexcept; - - /** - * Attempt to parse an ondemand::number. An ondemand::number may - * contain an integer value or a floating-point value, the simdjson - * library will autodetect the type. Thus it is a dynamically typed - * number. Before accessing the value, you must determine the detected - * type. 
- * - * number.get_number_type() is number_type::signed_integer if we have - * an integer in [-9223372036854775808,9223372036854775808) - * You can recover the value by calling number.get_int64() and you - * have that number.is_int64() is true. - * - * number.get_number_type() is number_type::unsigned_integer if we have - * an integer in [9223372036854775808,18446744073709551616) - * You can recover the value by calling number.get_uint64() and you - * have that number.is_uint64() is true. - * - * Otherwise, number.get_number_type() has value - * number_type::floating_point_number - * and we have a binary64 number. - * You can recover the value by calling number.get_double() and you - * have that number.is_double() is true. - * - * You must check the type before accessing the value: it is an error - * to call "get_int64()" when number.get_number_type() is not - * number_type::signed_integer and when number.is_int64() is false. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - get_number() noexcept; - - /** - * Get the raw JSON for this token. - * - * The string_view will always point into the input buffer. - * - * The string_view will start at the beginning of the token, and include the - * entire token - * *as well as all spaces until the next token (or EOF).* This means, for - * example, that a - * string token always begins with a " and is always terminated by the final - * ", possibly - * followed by a number of spaces. - * - * The string_view is *not* null-terminated. If this is a scalar (string, - * number, - * boolean, or null), the character after the end of the string_view may be - * the padded buffer. - * - * Tokens include: - * - { - * - [ - * - "a string (possibly with UTF-8 or backslashed characters like \\\")". 
- * - -1.2e-100 - * - true - * - false - * - null - */ - simdjson_really_inline simdjson_result - raw_json_token() noexcept; - - /** - * Reset the iterator inside the document instance so we are pointing back - * at the - * beginning of the document, as if it had just been created. It invalidates - * all - * values, objects and arrays that you have created so far (including - * unescaped strings). - */ - inline void rewind() noexcept; - /** - * Returns debugging information. - */ - inline std::string to_debug_string() noexcept; - /** - * Some unrecoverable error conditions may render the document instance - * unusable. - * The is_alive() method returns true when the document is still suitable. - */ - inline bool is_alive() noexcept; - - /** - * Returns the current location in the document if in bounds. - */ - inline simdjson_result current_location() noexcept; - - /** - * Get the value associated with the given JSON pointer. We use the RFC - * 6901 - * https://tools.ietf.org/html/rfc6901 standard. - * - * ondemand::parser parser; - * auto json = R"({ "foo": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("/foo/a/1") == 20 - * - * It is allowed for a key to be the empty string: - * - * ondemand::parser parser; - * auto json = R"({ "": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("//a/1") == 20 - * - * Note that at_pointer() automatically calls rewind between each call. Thus - * all values, objects and arrays that you have created so far (including - * unescaped strings) - * are invalidated. After calling at_pointer, you need to consume the - * result: string values - * should be stored in your own variables, arrays should be decoded and - * stored in your own array-like - * structures and so forth. 
- * - * Also note that at_pointer() relies on find_field() which implies that we - * do not unescape keys when matching - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - * - SCALAR_DOCUMENT_AS_VALUE if the json_pointer is empty and the - * document is not a scalar (see is_scalar() function). - */ - simdjson_really_inline simdjson_result at_pointer( - std::string_view json_pointer) noexcept; - /** - * Consumes the document and returns a string_view instance corresponding to - * the - * document as represented in JSON. It points inside the original byte array - * containing - * the JSON document. - */ - simdjson_really_inline simdjson_result - raw_json() noexcept; - - protected: - /** - * Consumes the document. 
- */ - simdjson_really_inline error_code consume() noexcept; - - simdjson_really_inline document(ondemand::json_iterator &&iter) noexcept; - simdjson_really_inline const uint8_t *text(uint32_t idx) const noexcept; - - simdjson_really_inline value_iterator resume_value_iterator() noexcept; - simdjson_really_inline value_iterator get_root_value_iterator() noexcept; - simdjson_really_inline simdjson_result - start_or_resume_object() noexcept; - static simdjson_really_inline document - start(ondemand::json_iterator &&iter) noexcept; - - // - // Fields - // - json_iterator iter{}; ///< Current position in the document - static constexpr depth_t DOCUMENT_DEPTH = - 0; ///< document depth is always 0 - - friend class array_iterator; - friend class value; - friend class ondemand::parser; - friend class object; - friend class array; - friend class field; - friend class token; - friend class document_stream; -}; - - -/** - * A document_reference is a thin wrapper around a document reference instance. 
- */ -class document_reference { - public: - simdjson_really_inline document_reference() noexcept; - simdjson_really_inline document_reference(document &d) noexcept; - simdjson_really_inline document_reference( - const document_reference &other) noexcept = default; - simdjson_really_inline document_reference &operator=( - const document_reference &other) noexcept = default; - simdjson_really_inline void rewind() noexcept; - simdjson_really_inline simdjson_result get_array() & noexcept; - simdjson_really_inline simdjson_result get_object() & noexcept; - simdjson_really_inline simdjson_result get_uint64() noexcept; - simdjson_really_inline simdjson_result get_int64() noexcept; - simdjson_really_inline simdjson_result get_double() noexcept; - simdjson_really_inline simdjson_result - get_string() noexcept; - simdjson_really_inline simdjson_result - get_raw_json_string() noexcept; - simdjson_really_inline simdjson_result get_bool() noexcept; - simdjson_really_inline simdjson_result get_value() noexcept; - - simdjson_really_inline bool is_null() noexcept; - simdjson_really_inline simdjson_result - raw_json() noexcept; - simdjson_really_inline operator document &() const noexcept; - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline operator array() & noexcept(false); - simdjson_really_inline operator object() & noexcept(false); - simdjson_really_inline operator uint64_t() noexcept(false); - simdjson_really_inline operator int64_t() noexcept(false); - simdjson_really_inline operator double() noexcept(false); - simdjson_really_inline operator std::string_view() noexcept(false); - simdjson_really_inline operator raw_json_string() noexcept(false); - simdjson_really_inline operator bool() noexcept(false); - simdjson_really_inline operator value() noexcept(false); -#endif - simdjson_really_inline simdjson_result count_elements() & noexcept; - simdjson_really_inline simdjson_result count_fields() & noexcept; - simdjson_really_inline simdjson_result at(size_t index) & noexcept; - 
simdjson_really_inline simdjson_result begin() & noexcept; - simdjson_really_inline simdjson_result end() & noexcept; - simdjson_really_inline simdjson_result find_field( - std::string_view key) & - noexcept; - simdjson_really_inline simdjson_result find_field(const char *key) & - noexcept; - simdjson_really_inline simdjson_result operator[]( - std::string_view key) & - noexcept; - simdjson_really_inline simdjson_result operator[](const char *key) & - noexcept; - simdjson_really_inline simdjson_result find_field_unordered( - std::string_view key) & - noexcept; - simdjson_really_inline simdjson_result find_field_unordered( - const char *key) & - noexcept; - - simdjson_really_inline simdjson_result type() noexcept; - simdjson_really_inline simdjson_result is_scalar() noexcept; - - simdjson_really_inline simdjson_result - current_location() noexcept; - simdjson_really_inline bool is_negative() noexcept; - simdjson_really_inline simdjson_result is_integer() noexcept; - simdjson_really_inline simdjson_result - get_number_type() noexcept; - simdjson_really_inline simdjson_result get_number() noexcept; - simdjson_really_inline simdjson_result - raw_json_token() noexcept; - simdjson_really_inline simdjson_result at_pointer( - std::string_view json_pointer) noexcept; - - private: - document *doc{nullptr}; -}; -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - simdjson_really_inline error_code rewind() noexcept; - - simdjson_really_inline - 
simdjson_result - get_array() & noexcept; - simdjson_really_inline - simdjson_result - get_object() & noexcept; - simdjson_really_inline simdjson_result get_uint64() noexcept; - simdjson_really_inline simdjson_result get_int64() noexcept; - simdjson_really_inline simdjson_result get_double() noexcept; - simdjson_really_inline simdjson_result - get_double_from_string() noexcept; - simdjson_really_inline simdjson_result - get_string() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> - get_raw_json_string() noexcept; - simdjson_really_inline simdjson_result get_bool() noexcept; - simdjson_really_inline - simdjson_result - get_value() noexcept; - simdjson_really_inline bool is_null() noexcept; - - template - simdjson_really_inline simdjson_result get() & noexcept; - template - simdjson_really_inline simdjson_result get() && noexcept; - - template - simdjson_really_inline error_code get(T &out) & noexcept; - template - simdjson_really_inline error_code get(T &out) && noexcept; - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() & - noexcept(false); - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() & - noexcept(false); - simdjson_really_inline operator uint64_t() noexcept(false); - simdjson_really_inline operator int64_t() noexcept(false); - simdjson_really_inline operator double() noexcept(false); - simdjson_really_inline operator std::string_view() noexcept(false); - simdjson_really_inline operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand:: - raw_json_string() noexcept(false); - simdjson_really_inline operator bool() noexcept(false); - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value() noexcept(false); -#endif - simdjson_really_inline simdjson_result count_elements() & noexcept; - simdjson_really_inline simdjson_result count_fields() & noexcept; - simdjson_really_inline - 
simdjson_result - at(size_t index) & noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - begin() & noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - end() & noexcept; - simdjson_really_inline - simdjson_result - find_field(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field(const char *key) & noexcept; - simdjson_really_inline - simdjson_result - operator[](std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - operator[](const char *key) & noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(const char *key) & noexcept; - simdjson_really_inline - simdjson_result - type() noexcept; - simdjson_really_inline simdjson_result is_scalar() noexcept; - simdjson_really_inline simdjson_result - current_location() noexcept; - simdjson_really_inline bool is_negative() noexcept; - simdjson_really_inline simdjson_result is_integer() noexcept; - simdjson_really_inline - simdjson_result - get_number_type() noexcept; - simdjson_really_inline - simdjson_result - get_number() noexcept; - /** @copydoc simdjson_really_inline std::string_view - * document::raw_json_token() const noexcept */ - simdjson_really_inline simdjson_result - raw_json_token() noexcept; - - simdjson_really_inline - simdjson_result - at_pointer(std::string_view json_pointer) noexcept; -}; - - -} // namespace simdjson - - -namespace simdjson { - -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference 
value, - error_code error) noexcept; - simdjson_really_inline simdjson_result() noexcept = default; - simdjson_really_inline error_code rewind() noexcept; - - simdjson_really_inline - simdjson_result - get_array() & noexcept; - simdjson_really_inline - simdjson_result - get_object() & noexcept; - simdjson_really_inline simdjson_result get_uint64() noexcept; - simdjson_really_inline simdjson_result get_int64() noexcept; - simdjson_really_inline simdjson_result get_double() noexcept; - simdjson_really_inline simdjson_result - get_string() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> - get_raw_json_string() noexcept; - simdjson_really_inline simdjson_result get_bool() noexcept; - simdjson_really_inline - simdjson_result - get_value() noexcept; - simdjson_really_inline bool is_null() noexcept; - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() & - noexcept(false); - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() & - noexcept(false); - simdjson_really_inline operator uint64_t() noexcept(false); - simdjson_really_inline operator int64_t() noexcept(false); - simdjson_really_inline operator double() noexcept(false); - simdjson_really_inline operator std::string_view() noexcept(false); - simdjson_really_inline operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand:: - raw_json_string() noexcept(false); - simdjson_really_inline operator bool() noexcept(false); - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value() noexcept(false); -#endif - simdjson_really_inline simdjson_result count_elements() & noexcept; - simdjson_really_inline simdjson_result count_fields() & noexcept; - simdjson_really_inline - simdjson_result - at(size_t index) & noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - begin() & noexcept; - 
simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - end() & noexcept; - simdjson_really_inline - simdjson_result - find_field(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field(const char *key) & noexcept; - simdjson_really_inline - simdjson_result - operator[](std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - operator[](const char *key) & noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(const char *key) & noexcept; - simdjson_really_inline - simdjson_result - type() noexcept; - simdjson_really_inline simdjson_result is_scalar() noexcept; - simdjson_really_inline simdjson_result - current_location() noexcept; - simdjson_really_inline bool is_negative() noexcept; - simdjson_really_inline simdjson_result is_integer() noexcept; - simdjson_really_inline - simdjson_result - get_number_type() noexcept; - simdjson_really_inline - simdjson_result - get_number() noexcept; - /** @copydoc simdjson_really_inline std::string_view - * document_reference::raw_json_token() const noexcept */ - simdjson_really_inline simdjson_result - raw_json_token() noexcept; - - simdjson_really_inline - simdjson_result - at_pointer(std::string_view json_pointer) noexcept; -}; - - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/document.h */ -/* begin file include/simdjson/generic/ondemand/value.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class array; -class document; -class field; -class object; -class raw_json_string; - -/** - * An ephemeral JSON value returned during iteration. - */ -class value { - public: - /** - * Create a new invalid value. - * - * Exists so you can declare a variable and later assign to it before use. 
- */ - simdjson_really_inline value() noexcept = default; - - /** - * Get this value as the given type. - * - * Supported types: object, array, raw_json_string, string_view, uint64_t, - * int64_t, double, bool - * - * You may use get_double(), get_bool(), get_uint64(), get_int64(), - * get_object(), get_array(), get_raw_json_string(), or get_string() - * instead. - * - * @returns A value of the given type, parsed from the JSON. - * @returns INCORRECT_TYPE If the JSON value is not the given type. - */ - template - simdjson_really_inline simdjson_result get() noexcept { - // Unless the simdjson library provides an inline implementation, - // calling this method should - // immediately fail. - static_assert(!sizeof(T), - "The get method with given type is not implemented by " - "the simdjson library."); - } - - /** - * Get this value as the given type. - * - * Supported types: object, array, raw_json_string, string_view, uint64_t, - * int64_t, double, bool - * - * @param out This is set to a value of the given type, parsed from the - * JSON. If there is an error, this may not be initialized. - * @returns INCORRECT_TYPE If the JSON value is not an object. - * @returns SUCCESS If the parse succeeded and the out parameter was set to - * the value. - */ - template - simdjson_really_inline error_code get(T &out) noexcept; - - /** - * Cast this JSON value to an array. - * - * @returns An object that can be used to iterate the array. - * @returns INCORRECT_TYPE If the JSON value is not an array. - */ - simdjson_really_inline simdjson_result get_array() noexcept; - - /** - * Cast this JSON value to an object. - * - * @returns An object that can be used to look up or iterate fields. - * @returns INCORRECT_TYPE If the JSON value is not an object. - */ - simdjson_really_inline simdjson_result get_object() noexcept; - - /** - * Cast this JSON value to an unsigned integer. - * - * @returns A unsigned 64-bit integer. 
- * @returns INCORRECT_TYPE If the JSON value is not a 64-bit unsigned - * integer. - */ - simdjson_really_inline simdjson_result get_uint64() noexcept; - - /** - * Cast this JSON value (inside string) to a unsigned integer. - * - * @returns A unsigned 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit unsigned - * integer. - */ - simdjson_really_inline simdjson_result - get_uint64_in_string() noexcept; - - /** - * Cast this JSON value to a signed integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit integer. - */ - simdjson_really_inline simdjson_result get_int64() noexcept; - - /** - * Cast this JSON value (inside string) to a signed integer. - * - * @returns A signed 64-bit integer. - * @returns INCORRECT_TYPE If the JSON value is not a 64-bit integer. - */ - simdjson_really_inline simdjson_result - get_int64_in_string() noexcept; - - /** - * Cast this JSON value to a double. - * - * @returns A double. - * @returns INCORRECT_TYPE If the JSON value is not a valid floating-point - * number. - */ - simdjson_really_inline simdjson_result get_double() noexcept; - - /** - * Cast this JSON value (inside string) to a double - * - * @returns A double. - * @returns INCORRECT_TYPE If the JSON value is not a valid floating-point - * number. - */ - simdjson_really_inline simdjson_result - get_double_in_string() noexcept; - - /** - * Cast this JSON value to a string. - * - * The string is guaranteed to be valid UTF-8. - * - * Equivalent to get(). - * - * Important: a value should be consumed once. Calling get_string() twice on - * the same value - * is an error. - * - * @returns An UTF-8 string. The string is stored in the parser and will be - * invalidated the next - * time it parses a document or when it is destroyed. - * @returns INCORRECT_TYPE if the JSON value is not a string. 
- */ - simdjson_really_inline simdjson_result - get_string() noexcept; - - /** - * Cast this JSON value to a raw_json_string. - * - * The string is guaranteed to be valid UTF-8, and may have escapes in it - * (e.g. \\ or \n). - * - * @returns A pointer to the raw JSON for the given string. - * @returns INCORRECT_TYPE if the JSON value is not a string. - */ - simdjson_really_inline simdjson_result - get_raw_json_string() noexcept; - - /** - * Cast this JSON value to a bool. - * - * @returns A bool value. - * @returns INCORRECT_TYPE if the JSON value is not true or false. - */ - simdjson_really_inline simdjson_result get_bool() noexcept; - - /** - * Checks if this JSON value is null. - * - * @returns Whether the value is null. - */ - simdjson_really_inline bool is_null() noexcept; - -#if SIMDJSON_EXCEPTIONS - /** - * Cast this JSON value to an array. - * - * @returns An object that can be used to iterate the array. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not an - * array. - */ - simdjson_really_inline operator array() noexcept(false); - /** - * Cast this JSON value to an object. - * - * @returns An object that can be used to look up or iterate fields. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not an - * object. - */ - simdjson_really_inline operator object() noexcept(false); - /** - * Cast this JSON value to an unsigned integer. - * - * @returns A signed 64-bit integer. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * 64-bit unsigned integer. - */ - simdjson_really_inline operator uint64_t() noexcept(false); - /** - * Cast this JSON value to a signed integer. - * - * @returns A signed 64-bit integer. - * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * 64-bit integer. - */ - simdjson_really_inline operator int64_t() noexcept(false); - /** - * Cast this JSON value to a double. - * - * @returns A double. 
- * @exception simdjson_error(INCORRECT_TYPE) If the JSON value is not a - * valid floating-point number. - */ - simdjson_really_inline operator double() noexcept(false); - /** - * Cast this JSON value to a string. - * - * The string is guaranteed to be valid UTF-8. - * - * Equivalent to get(). - * - * @returns An UTF-8 string. The string is stored in the parser and will be - * invalidated the next - * time it parses a document or when it is destroyed. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not a - * string. - */ - simdjson_really_inline operator std::string_view() noexcept(false); - /** - * Cast this JSON value to a raw_json_string. - * - * The string is guaranteed to be valid UTF-8, and may have escapes in it - * (e.g. \\ or \n). - * - * @returns A pointer to the raw JSON for the given string. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not a - * string. - */ - simdjson_really_inline operator raw_json_string() noexcept(false); - /** - * Cast this JSON value to a bool. - * - * @returns A bool value. - * @exception simdjson_error(INCORRECT_TYPE) if the JSON value is not true - * or false. - */ - simdjson_really_inline operator bool() noexcept(false); -#endif - - /** - * Begin array iteration. - * - * Part of the std::iterable interface. - * - * @returns INCORRECT_TYPE If the JSON value is not an array. - */ - simdjson_really_inline simdjson_result begin() & noexcept; - /** - * Sentinel representing the end of the array. - * - * Part of the std::iterable interface. - */ - simdjson_really_inline simdjson_result end() & noexcept; - /** - * This method scans the array and counts the number of elements. - * The count_elements method should always be called before you have begun - * iterating through the array: it is expected that you are pointing at - * the beginning of the array. - * The runtime complexity is linear in the size of the array. 
After - * calling this function, if successful, the array is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - */ - simdjson_really_inline simdjson_result count_elements() & noexcept; - /** - * This method scans the object and counts the number of key-value pairs. - * The count_fields method should always be called before you have begun - * iterating through the object: it is expected that you are pointing at - * the beginning of the object. - * The runtime complexity is linear in the size of the object. After - * calling this function, if successful, the object is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - * - * To check that an object is empty, it is more performant to use - * the is_empty() method on the object instance. - */ - simdjson_really_inline simdjson_result count_fields() & noexcept; - /** - * Get the value at the given index in the array. This function has - * linear-time complexity. - * This function should only be called once on an array instance since the - * array iterator is not reset between each call. - * - * @return The value at the given index, or: - * - INDEX_OUT_OF_BOUNDS if the array index is larger than an array - * length - */ - simdjson_really_inline simdjson_result at(size_t index) noexcept; - /** - * Look up a field by name on an object (order-sensitive). 
- * - * The following code reads z, then y, then x, and thus will not retrieve x - or y if fed the - * JSON `{ "x": 1, "y": 2, "z": 3 }`: - * - * ```c++ - * simdjson::ondemand::parser parser; - * auto obj = parser.parse(R"( { "x": 1, "y": 2, "z": 3 } )"_padded); - * double z = obj.find_field("z"); - * double y = obj.find_field("y"); - * double x = obj.find_field("x"); - * ``` - * If you have multiple fields with a matching key ({"x": 1, "x": 1}) be - mindful - * that only one field is returned. - - * **Raw Keys:** The lookup will be done against the *raw* key, and will not - unescape keys. - * e.g. `object["a"]` will match `{ "a": 1 }`, but will *not* match `{ - "\u0061": 1 }`. - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - the object. - */ - simdjson_really_inline simdjson_result find_field( - std::string_view key) noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field(std::string_view key) noexcept; */ - simdjson_really_inline simdjson_result find_field( - const char *key) noexcept; - - /** - * Look up a field by name on an object, without regard to key order. - * - * **Performance Notes:** This is a bit less performant than find_field(), - * though its effect varies - * and often appears negligible. It starts out normally, starting out at the - * last field; but if - * the field is not found, it scans from the beginning of the object to see - * if it missed it. That - * missing case has a non-cache-friendly bump and lots of extra scanning, - * especially if the object - * in question is large. The fact that the extra code is there also bumps - * the executable size. - * - * It is the default, however, because it would be highly surprising (and - * hard to debug) if the - * default behavior failed to look up a field just because it was in the - * wrong order--and many - * APIs assume this. 
Therefore, you must be explicit if you want to treat - * objects as out of order. - * - * If you have multiple fields with a matching key ({"x": 1, "x": 1}) be - * mindful - * that only one field is returned. - * - * Use find_field() if you are sure fields will be in order (or are willing - * to treat it as if the - * field wasn't there when they aren't). - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. - */ - simdjson_really_inline simdjson_result find_field_unordered( - std::string_view key) noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline simdjson_result find_field_unordered( - const char *key) noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline simdjson_result operator[]( - std::string_view key) noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline simdjson_result operator[]( - const char *key) noexcept; - - /** - * Get the type of this JSON value. - * - * NOTE: If you're only expecting a value to be one type (a typical case), - * it's generally - * better to just call .get_double, .get_string, etc. and check for - * INCORRECT_TYPE (or just - * let it throw an exception). - * - * @return The type of JSON value (json_type::array, json_type::object, - * json_type::string, - * json_type::number, json_type::boolean, or json_type::null). - * @error TAPE_ERROR when the JSON value is a bad token like "}" "," or - * "alse". - */ - simdjson_really_inline simdjson_result type() noexcept; - - /** - * Checks whether the value is a scalar (string, number, null, Boolean). - * Returns false when there it is an array or object. 
- * - * @returns true if the type is string, number, null, Boolean - * @error TAPE_ERROR when the JSON value is a bad token like "}" "," or - * "alse". - */ - simdjson_really_inline simdjson_result is_scalar() noexcept; - - /** - * Checks whether the value is a negative number. - * - * @returns true if the number if negative. - */ - simdjson_really_inline bool is_negative() noexcept; - /** - * Checks whether the value is an integer number. Note that - * this requires to partially parse the number string. If - * the value is determined to be an integer, it may still - * not parse properly as an integer in subsequent steps - * (e.g., it might overflow). - * - * Performance note: if you call this function systematically - * before parsing a number, you may have fallen for a performance - * anti-pattern. - * - * @returns true if the number if negative. - */ - simdjson_really_inline simdjson_result is_integer() noexcept; - /** - * Determine the number type (integer or floating-point number). - * - * get_number_type() is number_type::unsigned_integer if we have - * an integer greater or equal to 9223372036854775808 - * get_number_type() is number_type::signed_integer if we have an - * integer that is less than 9223372036854775808 - * Otherwise, get_number_type() has value number_type::floating_point_number - * - * This function requires processing the number string, but it is expected - * to be faster than get_number().get_number_type() because it is does not - * parse the number value. - * - * @returns the type of the number - */ - simdjson_really_inline simdjson_result - get_number_type() noexcept; - - /** - * Attempt to parse an ondemand::number. An ondemand::number may - * contain an integer value or a floating-point value, the simdjson - * library will autodetect the type. Thus it is a dynamically typed - * number. Before accessing the value, you must determine the detected - * type. 
- * - * number.get_number_type() is number_type::signed_integer if we have - * an integer in [-9223372036854775808,9223372036854775808) - * You can recover the value by calling number.get_int64() and you - * have that number.is_int64() is true. - * - * number.get_number_type() is number_type::unsigned_integer if we have - * an integer in [9223372036854775808,18446744073709551616) - * You can recover the value by calling number.get_uint64() and you - * have that number.is_uint64() is true. - * - * Otherwise, number.get_number_type() has value - * number_type::floating_point_number - * and we have a binary64 number. - * You can recover the value by calling number.get_double() and you - * have that number.is_double() is true. - * - * You must check the type before accessing the value: it is an error - * to call "get_int64()" when number.get_number_type() is not - * number_type::signed_integer and when number.is_int64() is false. - * - * Performance note: this is designed with performance in mind. When - * calling 'get_number()', you scan the number string only once, determining - * efficiently the type and storing it in an efficient manner. - */ - simdjson_warn_unused simdjson_really_inline simdjson_result - get_number() noexcept; - - - /** - * Get the raw JSON for this token. - * - * The string_view will always point into the input buffer. - * - * The string_view will start at the beginning of the token, and include the - * entire token - * *as well as all spaces until the next token (or EOF).* This means, for - * example, that a - * string token always begins with a " and is always terminated by the final - * ", possibly - * followed by a number of spaces. - * - * The string_view is *not* null-terminated. However, if this is a scalar - * (string, number, - * boolean, or null), the character after the end of the string_view is - * guaranteed to be - * a non-space token. 
- * - * Tokens include: - * - { - * - [ - * - "a string (possibly with UTF-8 or backslashed characters like \\\")". - * - -1.2e-100 - * - true - * - false - * - null - */ - simdjson_really_inline std::string_view raw_json_token() noexcept; - - /** - * Returns the current location in the document if in bounds. - */ - simdjson_really_inline simdjson_result - current_location() noexcept; - - /** - * Get the value associated with the given JSON pointer. We use the RFC - * 6901 - * https://tools.ietf.org/html/rfc6901 standard. - * - * ondemand::parser parser; - * auto json = R"({ "foo": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("/foo/a/1") == 20 - * - * It is allowed for a key to be the empty string: - * - * ondemand::parser parser; - * auto json = R"({ "": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("//a/1") == 20 - * - * Note that at_pointer() called on the document automatically calls the - * document's rewind - * method between each call. It invalidates all previously accessed arrays, - * objects and values - * that have not been consumed. - * - * Calling at_pointer() on non-document instances (e.g., arrays and objects) - * is not - * standardized (by RFC 6901). We provide some experimental support for JSON - * pointers - * on non-document instances. Yet it is not the case when calling - * at_pointer on an array - * or an object instance: there is no rewind and no invalidation. - * - * You may only call at_pointer on an array after it has been created, but - * before it has - * been first accessed. When calling at_pointer on an array, the pointer is - * advanced to - * the location indicated by the JSON pointer (in case of success). It is no - * longer possible - * to call at_pointer on the same array. 
- * - * You may call at_pointer more than once on an object, but each time the - * pointer is advanced - * to be within the value matched by the key indicated by the JSON pointer - * query. Thus any preceeding - * key (as well as the current key) can no longer be used with following - * JSON pointer calls. - * - * Also note that at_pointer() relies on find_field() which implies that we - * do not unescape keys when matching - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - simdjson_really_inline simdjson_result at_pointer( - std::string_view json_pointer) noexcept; - - protected: - /** - * Create a value. - */ - simdjson_really_inline value(const value_iterator &iter) noexcept; - - /** - * Skip this value, allowing iteration to continue. - */ - simdjson_really_inline void skip() noexcept; - - /** - * Start a value at the current position. - * - * (It should already be started; this is just a self-documentation method.) - */ - static simdjson_really_inline value - start(const value_iterator &iter) noexcept; - - /** - * Resume a value. 
- */ - static simdjson_really_inline value - resume(const value_iterator &iter) noexcept; - - /** - * Get the object, starting or resuming it as necessary - */ - simdjson_really_inline simdjson_result - start_or_resume_object() noexcept; - - // simdjson_really_inline void log_value(const char *type) const noexcept; - // simdjson_really_inline void log_error(const char *message) const - // noexcept; - - value_iterator iter{}; - - friend class document; - friend class array_iterator; - friend class field; - friend class object; - friend struct simdjson_result; - friend struct simdjson_result; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - simdjson_really_inline - simdjson_result - get_array() noexcept; - simdjson_really_inline - simdjson_result - get_object() noexcept; - - simdjson_really_inline simdjson_result get_uint64() noexcept; - simdjson_really_inline simdjson_result - get_uint64_in_string() noexcept; - simdjson_really_inline simdjson_result get_int64() noexcept; - simdjson_really_inline simdjson_result - get_int64_in_string() noexcept; - simdjson_really_inline simdjson_result get_double() noexcept; - simdjson_really_inline simdjson_result - get_double_in_string() noexcept; - simdjson_really_inline simdjson_result - get_string() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> - get_raw_json_string() noexcept; - simdjson_really_inline simdjson_result 
get_bool() noexcept; - simdjson_really_inline bool is_null() noexcept; - - template - simdjson_really_inline simdjson_result get() noexcept; - - template - simdjson_really_inline error_code get(T &out) noexcept; - -#if SIMDJSON_EXCEPTIONS - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() noexcept(false); - simdjson_really_inline - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() noexcept( - false); - simdjson_really_inline operator uint64_t() noexcept(false); - simdjson_really_inline operator int64_t() noexcept(false); - simdjson_really_inline operator double() noexcept(false); - simdjson_really_inline operator std::string_view() noexcept(false); - simdjson_really_inline operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand:: - raw_json_string() noexcept(false); - simdjson_really_inline operator bool() noexcept(false); -#endif - simdjson_really_inline simdjson_result count_elements() & noexcept; - simdjson_really_inline simdjson_result count_fields() & noexcept; - simdjson_really_inline - simdjson_result - at(size_t index) noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - begin() & noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - end() & noexcept; - - /** - * Look up a field by name on an object (order-sensitive). - * - * The following code reads z, then y, then x, and thus will not retrieve x - * or y if fed the - * JSON `{ "x": 1, "y": 2, "z": 3 }`: - * - * ```c++ - * simdjson::ondemand::parser parser; - * auto obj = parser.parse(R"( { "x": 1, "y": 2, "z": 3 } )"_padded); - * double z = obj.find_field("z"); - * double y = obj.find_field("y"); - * double x = obj.find_field("x"); - * ``` - * - * **Raw Keys:** The lookup will be done against the *raw* key, and will not - * unescape keys. - * e.g. `object["a"]` will match `{ "a": 1 }`, but will *not* match `{ - * "\u0061": 1 }`. 
- * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. - */ - simdjson_really_inline - simdjson_result - find_field(std::string_view key) noexcept; - /** @overload simdjson_really_inline - * simdjson_result - * find_field(std::string_view key) noexcept; */ - simdjson_really_inline - simdjson_result - find_field(const char *key) noexcept; - - /** - * Look up a field by name on an object, without regard to key order. - * - * **Performance Notes:** This is a bit less performant than find_field(), - * though its effect varies - * and often appears negligible. It starts out normally, starting out at the - * last field; but if - * the field is not found, it scans from the beginning of the object to see - * if it missed it. That - * missing case has a non-cache-friendly bump and lots of extra scanning, - * especially if the object - * in question is large. The fact that the extra code is there also bumps - * the executable size. - * - * It is the default, however, because it would be highly surprising (and - * hard to debug) if the - * default behavior failed to look up a field just because it was in the - * wrong order--and many - * APIs assume this. Therefore, you must be explicit if you want to treat - * objects as out of order. - * - * Use find_field() if you are sure fields will be in order (or are willing - * to treat it as if the - * field wasn't there when they aren't). - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. 
- */ - simdjson_really_inline - simdjson_result - find_field_unordered(std::string_view key) noexcept; - /** @overload simdjson_really_inline - * simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline - simdjson_result - find_field_unordered(const char *key) noexcept; - /** @overload simdjson_really_inline - * simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline - simdjson_result - operator[](std::string_view key) noexcept; - /** @overload simdjson_really_inline - * simdjson_result - * find_field_unordered(std::string_view key) noexcept; */ - simdjson_really_inline - simdjson_result - operator[](const char *key) noexcept; - - /** - * Get the type of this JSON value. - * - * NOTE: If you're only expecting a value to be one type (a typical case), - * it's generally - * better to just call .get_double, .get_string, etc. and check for - * INCORRECT_TYPE (or just - * let it throw an exception). - */ - simdjson_really_inline - simdjson_result - type() noexcept; - simdjson_really_inline simdjson_result is_scalar() noexcept; - simdjson_really_inline simdjson_result is_negative() noexcept; - simdjson_really_inline simdjson_result is_integer() noexcept; - simdjson_really_inline - simdjson_result - get_number_type() noexcept; - simdjson_really_inline - simdjson_result - get_number() noexcept; - - /** @copydoc simdjson_really_inline std::string_view value::raw_json_token() - * const noexcept */ - simdjson_really_inline simdjson_result - raw_json_token() noexcept; - - /** @copydoc simdjson_really_inline simdjson_result - * current_location() noexcept */ - simdjson_really_inline simdjson_result - current_location() noexcept; - - simdjson_really_inline - simdjson_result - at_pointer(std::string_view json_pointer) noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/value.h */ -/* begin file include/simdjson/generic/ondemand/field.h */ - -namespace 
simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -/** - * A JSON field (key/value pair) in an object. - * - * Returned from object iteration. - * - * Extends from std::pair so you can use C++ algorithms - * that rely on pairs. - */ -class field : public std::pair { - public: - /** - * Create a new invalid field. - * - * Exists so you can declare a variable and later assign to it before use. - */ - simdjson_really_inline field() noexcept; - - /** - * Get the key as a string_view (for higher speed, consider raw_key). - * We deliberately use a more cumbersome name (unescaped_key) to force users - * to think twice about using it. - * - * This consumes the key: once you have called unescaped_key(), you cannot - * call it again nor can you call key(). - */ - simdjson_really_inline simdjson_warn_unused - simdjson_result - unescaped_key() noexcept; - /** - * Get the key as a raw_json_string. Can be used for direct comparison with - * an unescaped C string: e.g., key() == "test". - */ - simdjson_really_inline raw_json_string key() const noexcept; - /** - * Get the field value. 
- */ - simdjson_really_inline ondemand::value &value() & noexcept; - /** - * @overload ondemand::value &ondemand::value() & noexcept - */ - simdjson_really_inline ondemand::value value() && noexcept; - - protected: - simdjson_really_inline field(raw_json_string key, - ondemand::value &&value) noexcept; - static simdjson_really_inline simdjson_result start( - value_iterator &parent_iter) noexcept; - static simdjson_really_inline simdjson_result start( - const value_iterator &parent_iter, raw_json_string key) noexcept; - friend struct simdjson_result; - friend class object_iterator; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - simdjson_really_inline simdjson_result - unescaped_key() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> - key() noexcept; - simdjson_really_inline - simdjson_result - value() noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/field.h */ -/* begin file include/simdjson/generic/ondemand/object.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -/** - * A forward-only JSON object field iterator. - */ -class object { - public: - /** - * Create a new invalid object. - * - * Exists so you can declare a variable and later assign to it before use. 
- */ - simdjson_really_inline object() noexcept = default; - - simdjson_really_inline simdjson_result begin() noexcept; - simdjson_really_inline simdjson_result end() noexcept; - /** - * Look up a field by name on an object (order-sensitive). - * - * The following code reads z, then y, then x, and thus will not retrieve x - * or y if fed the - * JSON `{ "x": 1, "y": 2, "z": 3 }`: - * - * ```c++ - * simdjson::ondemand::parser parser; - * auto obj = parser.parse(R"( { "x": 1, "y": 2, "z": 3 } )"_padded); - * double z = obj.find_field("z"); - * double y = obj.find_field("y"); - * double x = obj.find_field("x"); - * ``` - * If you have multiple fields with a matching key ({"x": 1, "x": 1}) be - * mindful - * that only one field is returned. - * - * **Raw Keys:** The lookup will be done against the *raw* key, and will not - * unescape keys. - * e.g. `object["a"]` will match `{ "a": 1 }`, but will *not* match `{ - * "\u0061": 1 }`. - * - * You must consume the fields on an object one at a time. A request for a - * new key - * invalidates previous field values: it makes them unsafe. E.g., the array - * given by content["bids"].get_array() should not be accessed after you - * have called - * content["asks"].get_array(). You can detect such mistakes by first - * compiling and running - * the code in Debug mode (or with the macro `SIMDJSON_DEVELOPMENT_CHECKS` - * set to 1): an - * OUT_OF_ORDER_ITERATION error is generated. - * - * You are expected to access keys only once. You should access the value - * corresponding to a - * key a single time. Doing object["mykey"].to_string() and then again - * object["mykey"].to_string() - * is an error. - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. 
- */ - simdjson_really_inline simdjson_result find_field( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result find_field( - std::string_view key) && - noexcept; - - /** - * Look up a field by name on an object, without regard to key order. - * - * **Performance Notes:** This is a bit less performant than find_field(), - * though its effect varies - * and often appears negligible. It starts out normally, starting out at the - * last field; but if - * the field is not found, it scans from the beginning of the object to see - * if it missed it. That - * missing case has a non-cache-friendly bump and lots of extra scanning, - * especially if the object - * in question is large. The fact that the extra code is there also bumps - * the executable size. - * - * It is the default, however, because it would be highly surprising (and - * hard to debug) if the - * default behavior failed to look up a field just because it was in the - * wrong order--and many - * APIs assume this. Therefore, you must be explicit if you want to treat - * objects as out of order. - * - * Use find_field() if you are sure fields will be in order (or are willing - * to treat it as if the - * field wasn't there when they aren't). - * - * If you have multiple fields with a matching key ({"x": 1, "x": 1}) be - * mindful - * that only one field is returned. - * - * You must consume the fields on an object one at a time. A request for a - * new key - * invalidates previous field values: it makes them unsafe. E.g., the array - * given by content["bids"].get_array() should not be accessed after you - * have called - * content["asks"].get_array(). You can detect such mistakes by first - * compiling and running - * the code in Debug mode (or with the macro `SIMDJSON_DEVELOPMENT_CHECKS` - * set to 1): an - * OUT_OF_ORDER_ITERATION error is generated. 
- * - * You are expected to access keys only once. You should access the value - * corresponding to a key - * a single time. Doing object["mykey"].to_string() and then again - * object["mykey"].to_string() is an error. - * - * @param key The key to look up. - * @returns The value of the field, or NO_SUCH_FIELD if the field is not in - * the object. - */ - simdjson_really_inline simdjson_result find_field_unordered( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result find_field_unordered( - std::string_view key) && - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result operator[]( - std::string_view key) & - noexcept; - /** @overload simdjson_really_inline simdjson_result - * find_field_unordered(std::string_view key) & noexcept; */ - simdjson_really_inline simdjson_result operator[]( - std::string_view key) && - noexcept; - - /** - * Get the value associated with the given JSON pointer. We use the RFC 6901 - * https://tools.ietf.org/html/rfc6901 standard, interpreting the current - * node - * as the root of its own JSON document. - * - * ondemand::parser parser; - * auto json = R"({ "foo": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("/foo/a/1") == 20 - * - * It is allowed for a key to be the empty string: - * - * ondemand::parser parser; - * auto json = R"({ "": { "a": [ 10, 20, 30 ] }})"_padded; - * auto doc = parser.iterate(json); - * doc.at_pointer("//a/1") == 20 - * - * Note that at_pointer() called on the document automatically calls the - * document's rewind - * method between each call. It invalidates all previously accessed arrays, - * objects and values - * that have not been consumed. 
Yet it is not the case when calling - * at_pointer on an object - * instance: there is no rewind and no invalidation. - * - * You may call at_pointer more than once on an object, but each time the - * pointer is advanced - * to be within the value matched by the key indicated by the JSON pointer - * query. Thus any preceeding - * key (as well as the current key) can no longer be used with following - * JSON pointer calls. - * - * Also note that at_pointer() relies on find_field() which implies that we - * do not unescape keys when matching. - * - * @return The value associated with the given JSON pointer, or: - * - NO_SUCH_FIELD if a field does not exist in an object - * - INDEX_OUT_OF_BOUNDS if an array index is larger than an array - * length - * - INCORRECT_TYPE if a non-integer is used to access an array - * - INVALID_JSON_POINTER if the JSON pointer is invalid and cannot - * be parsed - */ - inline simdjson_result at_pointer( - std::string_view json_pointer) noexcept; - - /** - * Reset the iterator so that we are pointing back at the - * beginning of the object. You should still consume values only once even - * if you - * can iterate through the object more than once. If you unescape a string - * within - * the object more than once, you have unsafe code. Note that rewinding an - * object - * means that you may need to reparse it anew: it is not a free operation. - * - * @returns true if the object contains some elements (not empty) - */ - inline simdjson_result reset() & noexcept; - /** - * This method scans the beginning of the object and checks whether the - * object is empty. - * The runtime complexity is constant time. After - * calling this function, if successful, the object is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. 
- */ - inline simdjson_result is_empty() & noexcept; - /** - * This method scans the object and counts the number of key-value pairs. - * The count_fields method should always be called before you have begun - * iterating through the object: it is expected that you are pointing at - * the beginning of the object. - * The runtime complexity is linear in the size of the object. After - * calling this function, if successful, the object is 'rewinded' at its - * beginning as if it had never been accessed. If the JSON is malformed - * (e.g., - * there is a missing comma), then an error is returned and it is no longer - * safe to continue. - * - * To check that an object is empty, it is more performant to use - * the is_empty() method. - */ - simdjson_really_inline simdjson_result count_fields() & noexcept; - /** - * Consumes the object and returns a string_view instance corresponding to - * the - * object as represented in JSON. It points inside the original byte array - * containg - * the JSON document. - */ - simdjson_really_inline simdjson_result - raw_json() noexcept; - - protected: - /** - * Go to the end of the object, no matter where you are right now. 
- */ - simdjson_really_inline error_code consume() noexcept; - static simdjson_really_inline simdjson_result start( - value_iterator &iter) noexcept; - static simdjson_really_inline simdjson_result start_root( - value_iterator &iter) noexcept; - static simdjson_really_inline simdjson_result started( - value_iterator &iter) noexcept; - static simdjson_really_inline object - resume(const value_iterator &iter) noexcept; - simdjson_really_inline object(const value_iterator &iter) noexcept; - - simdjson_warn_unused simdjson_really_inline error_code - find_field_raw(const std::string_view key) noexcept; - - value_iterator iter{}; - - friend class value; - friend class document; - friend struct simdjson_result; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; - - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> - begin() noexcept; - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> - end() noexcept; - simdjson_really_inline - simdjson_result - find_field(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field(std::string_view key) && noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - find_field_unordered(std::string_view key) && noexcept; - simdjson_really_inline - simdjson_result - 
operator[](std::string_view key) & noexcept; - simdjson_really_inline - simdjson_result - operator[](std::string_view key) && noexcept; - simdjson_really_inline - simdjson_result - at_pointer(std::string_view json_pointer) noexcept; - inline simdjson_result reset() noexcept; - inline simdjson_result is_empty() noexcept; - inline simdjson_result count_fields() & noexcept; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/object.h */ -/* begin file include/simdjson/generic/ondemand/parser.h */ - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class array; -class object; -class value; -class raw_json_string; -class document_stream; - -/** - * The default batch size for document_stream instances for this On Demand - * kernel. - * Note that different On Demand kernel may use a different DEFAULT_BATCH_SIZE - * value - * in the future. - */ -static constexpr size_t DEFAULT_BATCH_SIZE = 1000000; -/** - * Some adversary might try to set the batch size to 0 or 1, which might cause - * problems. - * We set a minimum of 32B since anything else is highly likely to be an error. - * In practice, - * most users will want a much larger batch size. - * - * All non-negative MINIMAL_BATCH_SIZE values should be 'safe' except that, - * obviously, no JSON - * document can ever span 0 or 1 byte and that very large values would create - * memory allocation issues. - */ -static constexpr size_t MINIMAL_BATCH_SIZE = 32; - -/** - * A JSON fragment iterator. - * - * This holds the actual iterator as well as the buffer for writing strings. - */ -class parser { - public: - /** - * Create a JSON parser. - * - * The new parser will have zero capacity. 
- */ - inline explicit parser( - size_t max_capacity = SIMDJSON_MAXSIZE_BYTES) noexcept; - - inline parser(parser &&other) noexcept = default; - simdjson_really_inline parser(const parser &other) = delete; - simdjson_really_inline parser &operator=(const parser &other) = delete; - simdjson_really_inline parser &operator=(parser &&other) noexcept = default; - - /** Deallocate the JSON parser. */ - inline ~parser() noexcept = default; - - /** - * Start iterating an on-demand JSON document. - * - * ondemand::parser parser; - * document doc = parser.iterate(json); - * - * It is expected that the content is a valid UTF-8 file, containing a valid - * JSON document. - * Otherwise the iterate method may return an error. In particular, the - * whole input should be - * valid: we do not attempt to tolerate incorrect content either before or - * after a JSON - * document. - * - * ### IMPORTANT: Validate what you use - * - * Calling iterate on an invalid JSON document may not immediately trigger - * an error. The call to - * iterate does not parse and validate the whole document. - * - * ### IMPORTANT: Buffer Lifetime - * - * Because parsing is done while you iterate, you *must* keep the JSON - * buffer around at least as - * long as the document iteration. - * - * ### IMPORTANT: Document Lifetime - * - * Only one iteration at a time can happen per parser, and the parser *must* - * be kept alive during - * iteration to ensure intermediate buffers can be accessed. Any document - * must be destroyed before - * you call parse() again or destroy the parser. - * - * ### REQUIRED: Buffer Padding - * - * The buffer must have at least SIMDJSON_PADDING extra allocated bytes. It - * does not matter what - * those bytes are initialized to, as long as they are allocated. - * - * @param json The JSON to parse. - * @param len The length of the JSON. - * @param capacity The number of bytes allocated in the JSON (must be at - * least len+SIMDJSON_PADDING). 
- * - * @return The document, or an error: - * - INSUFFICIENT_PADDING if the input has less than - * SIMDJSON_PADDING extra bytes. - * - MEMALLOC if realloc_if_needed the parser does not have enough - * capacity, and memory - * allocation fails. - * - EMPTY if the document is all whitespace. - * - UTF8_ERROR if the document is not valid UTF-8. - * - UNESCAPED_CHARS if a string contains control characters that - * must be escaped - * - UNCLOSED_STRING if there is an unclosed string in the document. - */ - simdjson_warn_unused simdjson_result iterate( - padded_string_view json) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate(const char *json, - size_t len, - size_t capacity) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate(const uint8_t *json, - size_t len, - size_t capacity) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate( - std::string_view json, size_t capacity) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate( - const std::string &json) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate( - const simdjson_result &json) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate( - const simdjson_result &json) & - noexcept; - /** @overload simdjson_result iterate(padded_string_view json) & - * noexcept */ - simdjson_warn_unused simdjson_result iterate( - padded_string &&json) &noexcept = delete; - - /** - * @private - * - * Start iterating an on-demand JSON document. 
- * - * ondemand::parser parser; - * json_iterator doc = parser.iterate(json); - * - * ### IMPORTANT: Buffer Lifetime - * - * Because parsing is done while you iterate, you *must* keep the JSON - * buffer around at least as - * long as the document iteration. - * - * ### IMPORTANT: Document Lifetime - * - * Only one iteration at a time can happen per parser, and the parser *must* - * be kept alive during - * iteration to ensure intermediate buffers can be accessed. Any document - * must be destroyed before - * you call parse() again or destroy the parser. - * - * The ondemand::document instance holds the iterator. The document must - * remain in scope - * while you are accessing instances of ondemand::value, ondemand::object, - * ondemand::array. - * - * ### REQUIRED: Buffer Padding - * - * The buffer must have at least SIMDJSON_PADDING extra allocated bytes. It - * does not matter what - * those bytes are initialized to, as long as they are allocated. - * - * @param json The JSON to parse. - * - * @return The iterator, or an error: - * - INSUFFICIENT_PADDING if the input has less than - * SIMDJSON_PADDING extra bytes. - * - MEMALLOC if realloc_if_needed the parser does not have enough - * capacity, and memory - * allocation fails. - * - EMPTY if the document is all whitespace. - * - UTF8_ERROR if the document is not valid UTF-8. - * - UNESCAPED_CHARS if a string contains control characters that - * must be escaped - * - UNCLOSED_STRING if there is an unclosed string in the document. - */ - simdjson_warn_unused simdjson_result iterate_raw( - padded_string_view json) & - noexcept; - - - /** - * Parse a buffer containing many JSON documents. - * - * auto json = R"({ "foo": 1 } { "foo": 2 } { "foo": 3 } )"_padded; - * ondemand::parser parser; - * ondemand::document_stream docs = parser.iterate_many(json); - * for (auto & doc : docs) { - * std::cout << doc["foo"] << std::endl; - * } - * // Prints 1 2 3 - * - * No copy of the input buffer is made. 
- * - * The function is lazy: it may be that no more than one JSON document at a - * time is parsed. - * - * The caller is responsabile to ensure that the input string data remains - * unchanged and is - * not deleted during the loop. - * - * ### Format - * - * The buffer must contain a series of one or more JSON documents, - * concatenated into a single - * buffer, separated by ASCII whitespace. It effectively parses until it has - * a fully valid document, - * then starts parsing the next document at that point. (It does this with - * more parallelism and - * lookahead than you might think, though.) - * - * documents that consist of an object or array may omit the whitespace - * between them, concatenating - * with no separator. Documents that consist of a single primitive (i.e. - * documents that are not - * arrays or objects) MUST be separated with ASCII whitespace. - * - * The characters inside a JSON document, and between JSON documents, must - * be valid Unicode (UTF-8). - * - * The documents must not exceed batch_size bytes (by default 1MB) or they - * will fail to parse. - * Setting batch_size to excessively large or excessively small values may - * impact negatively the - * performance. - * - * ### REQUIRED: Buffer Padding - * - * The buffer must have at least SIMDJSON_PADDING extra allocated bytes. It - * does not matter what - * those bytes are initialized to, as long as they are allocated. - * - * ### Threads - * - * When compiled with SIMDJSON_THREADS_ENABLED, this method will use a - * single thread under the - * hood to do some lookahead. - * - * ### Parser Capacity - * - * If the parser's current capacity is less than batch_size, it will - * allocate enough capacity - * to handle it (up to max_capacity). - * - * @param buf The concatenated JSON to parse. - * @param len The length of the concatenated JSON. - * @param batch_size The batch size to use. MUST be larger than the largest - * document. 
The sweet - * spot is cache-related: small enough to fit in cache, - * yet big enough to - * parse as many documents as possible in one tight loop. - * Defaults to 10MB, which has been a reasonable sweet - * spot in our tests. - * @return The stream, or an error. An empty input will yield 0 documents - * rather than an EMPTY error. Errors: - * - MEMALLOC if the parser does not have enough capacity and memory - * allocation fails - * - CAPACITY if the parser does not have enough capacity and - * batch_size > max_capacity. - * - other json errors if parsing fails. You should not rely on - * these errors to always the same for the - * same document: they may vary under runtime dispatch (so they - * may vary depending on your system and hardware). - */ - inline simdjson_result iterate_many( - const uint8_t *buf, - size_t len, - size_t batch_size = DEFAULT_BATCH_SIZE) noexcept; - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result iterate_many( - const char *buf, - size_t len, - size_t batch_size = DEFAULT_BATCH_SIZE) noexcept; - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result iterate_many( - const std::string &s, size_t batch_size = DEFAULT_BATCH_SIZE) noexcept; - inline simdjson_result iterate_many( - const std::string &&s, size_t batch_size) = delete; // unsafe - /** @overload parse_many(const uint8_t *buf, size_t len, size_t batch_size) - */ - inline simdjson_result iterate_many( - const padded_string &s, - size_t batch_size = DEFAULT_BATCH_SIZE) noexcept; - inline simdjson_result iterate_many( - const padded_string &&s, size_t batch_size) = delete; // unsafe - - /** @private We do not want to allow implicit conversion from C string to - * std::string. */ - simdjson_result iterate_many( - const char *buf, - size_t batch_size = DEFAULT_BATCH_SIZE) noexcept = delete; - - /** The capacity of this parser (the largest document it can process). 
*/ - simdjson_really_inline size_t capacity() const noexcept; - /** The maximum capacity of this parser (the largest document it is allowed - * to process). */ - simdjson_really_inline size_t max_capacity() const noexcept; - simdjson_really_inline void set_max_capacity(size_t max_capacity) noexcept; - /** The maximum depth of this parser (the most deeply nested objects and - * arrays it can process). */ - simdjson_really_inline size_t max_depth() const noexcept; - - /** - * Ensure this parser has enough memory to process JSON documents up to - * `capacity` bytes in length - * and `max_depth` depth. - * - * @param capacity The new capacity. - * @param max_depth The new max_depth. Defaults to DEFAULT_MAX_DEPTH. - * @return The error, if there is one. - */ - simdjson_warn_unused error_code - allocate(size_t capacity, size_t max_depth = DEFAULT_MAX_DEPTH) noexcept; - -#ifdef SIMDJSON_THREADS_ENABLED - /** - * The parser instance can use threads when they are available to speed up - * some - * operations. It is enabled by default. Changing this attribute will change - * the - * behavior of the parser for future operations. 
- */ - bool threaded{true}; -#endif - - private: - /** @private [for benchmarking access] The implementation to use */ - std::unique_ptr implementation{}; - size_t _capacity{0}; - size_t _max_capacity; - size_t _max_depth{DEFAULT_MAX_DEPTH}; - std::unique_ptr string_buf{}; -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - std::unique_ptr start_positions{}; -#endif - - friend class json_iterator; - friend class document_stream; -}; - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -template <> -struct simdjson_result - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::parser> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::parser - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/parser.h */ -/* begin file include/simdjson/generic/ondemand/document_stream.h */ -#ifdef SIMDJSON_THREADS_ENABLED -#include -#include -#include -#endif - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -class parser; -class json_iterator; -class document; - -#ifdef SIMDJSON_THREADS_ENABLED -/** @private Custom worker class **/ -struct stage1_worker { - stage1_worker() noexcept = default; - stage1_worker(const stage1_worker &) = delete; - stage1_worker(stage1_worker &&) = delete; - stage1_worker operator=(const stage1_worker &) = delete; - ~stage1_worker(); - /** - * We only start the thread when it is needed, not at object construction, - *this may throw. - * You should only call this once. - **/ - void start_thread(); - /** - * Start a stage 1 job. You should first call 'run', then 'finish'. - * You must call start_thread once before. 
- */ - void run(document_stream *ds, parser *stage1, size_t next_batch_start); - /** Wait for the run to finish (blocking). You should first call 'run', then - * 'finish'. **/ - void finish(); - - private: - /** - * Normally, we would never stop the thread. But we do in the destructor. - * This function is only safe assuming that you are not waiting for results. - *You - * should have called run, then finish, and be done. - **/ - void stop_thread(); - - std::thread thread{}; - /** These three variables define the work done by the thread. **/ - ondemand::parser *stage1_thread_parser{}; - size_t _next_batch_start{}; - document_stream *owner{}; - /** - * We have two state variables. This could be streamlined to one variable in - * the future but - * we use two for clarity. - */ - bool has_work{false}; - bool can_work{true}; - - /** - * We lock using a mutex. - */ - std::mutex locking_mutex{}; - std::condition_variable cond_var{}; - - friend class document_stream; -}; -#endif // SIMDJSON_THREADS_ENABLED - -/** - * A forward-only stream of documents. - * - * Produced by parser::iterate_many. - * - */ -class document_stream { - public: - /** - * Construct an uninitialized document_stream. - * - * ```c++ - * document_stream docs; - * auto error = parser.iterate_many(json).get(docs); - * ``` - */ - simdjson_really_inline document_stream() noexcept; - /** Move one document_stream to another. */ - simdjson_really_inline document_stream(document_stream &&other) noexcept = - default; - /** Move one document_stream to another. */ - simdjson_really_inline document_stream &operator=( - document_stream &&other) noexcept = default; - - simdjson_really_inline ~document_stream() noexcept; - - /** - * Returns the input size in bytes. - */ - inline size_t size_in_bytes() const noexcept; - - /** - * After iterating through the stream, this method - * returns the number of bytes that were not parsed at the end - * of the stream. 
If truncated_bytes() differs from zero, - * then the input was truncated maybe because incomplete JSON - * documents were found at the end of the stream. You - * may need to process the bytes in the interval - * [size_in_bytes()-truncated_bytes(), size_in_bytes()). - * - * You should only call truncated_bytes() after streaming through all - * documents, like so: - * - * document_stream stream = parser.iterate_many(json,window); - * for(auto & doc : stream) { - * // do something with doc - * } - * size_t truncated = stream.truncated_bytes(); - * - */ - inline size_t truncated_bytes() const noexcept; - - class iterator { - public: - using value_type = simdjson_result; - using reference = value_type; - - using difference_type = std::ptrdiff_t; - - using iterator_category = std::input_iterator_tag; - - /** - * Default constructor. - */ - simdjson_really_inline iterator() noexcept; - /** - * Get the current document (or error). - */ - simdjson_really_inline simdjson_result - operator*() noexcept; - /** - * Advance to the next document (prefix). - */ - inline iterator &operator++() noexcept; - /** - * Check if we're at the end yet. - * @param other the end iterator to compare to. - */ - simdjson_really_inline bool operator!=(const iterator &other) const - noexcept; - /** - * @private - * - * Gives the current index in the input document in bytes. - * - * document_stream stream = parser.parse_many(json,window); - * for(auto i = stream.begin(); i != stream.end(); ++i) { - * auto doc = *i; - * size_t index = i.current_index(); - * } - * - * This function (current_index()) is experimental and the usage - * may change in future versions of simdjson: we find the API somewhat - * awkward and we would like to offer something friendlier. - */ - simdjson_really_inline size_t current_index() const noexcept; - - /** - * @private - * - * Gives a view of the current document at the current position. 
- * - * document_stream stream = parser.iterate_many(json,window); - * for(auto i = stream.begin(); i != stream.end(); ++i) { - * std::string_view v = i.source(); - * } - * - * The returned string_view instance is simply a map to the (unparsed) - * source string: it may thus include white-space characters and all - * manner - * of padding. - * - * This function (source()) is experimental and the usage - * may change in future versions of simdjson: we find the API somewhat - * awkward and we would like to offer something friendlier. - * - */ - simdjson_really_inline std::string_view source() const noexcept; - - /** - * Returns error of the stream (if any). - */ - inline error_code error() const noexcept; - - private: - simdjson_really_inline iterator(document_stream *s, - bool finished) noexcept; - /** The document_stream we're iterating through. */ - document_stream *stream; - /** Whether we're finished or not. */ - bool finished; - - friend class document; - friend class document_stream; - friend class json_iterator; - }; - - /** - * Start iterating the documents in the stream. - */ - simdjson_really_inline iterator begin() noexcept; - /** - * The end of the stream, for iterator comparison purposes. - */ - simdjson_really_inline iterator end() noexcept; - - private: - document_stream &operator=(const document_stream &) = - delete; // Disallow copying - document_stream(const document_stream &other) = delete; // Disallow copying - - /** - * Construct a document_stream. Does not allocate or parse anything until - * the iterator is - * used. 
- * - * @param parser is a reference to the parser instance used to generate this - * document_stream - * @param buf is the raw byte buffer we need to process - * @param len is the length of the raw byte buffer in bytes - * @param batch_size is the size of the windows (must be strictly greater or - * equal to the largest JSON document) - */ - simdjson_really_inline document_stream(ondemand::parser &parser, - const uint8_t *buf, - size_t len, - size_t batch_size) noexcept; - - /** - * Parse the first document in the buffer. Used by begin(), to handle - * allocation and - * initialization. - */ - inline void start() noexcept; - - /** - * Parse the next document found in the buffer previously given to - * document_stream. - * - * The content should be a valid JSON document encoded as UTF-8. If there is - * a - * UTF-8 BOM, the caller is responsible for omitting it, UTF-8 BOM are - * discouraged. - * - * You do NOT need to pre-allocate a parser. This function takes care of - * pre-allocating a capacity defined by the batch_size defined when creating - * the - * document_stream object. - * - * The function returns simdjson::EMPTY if there is no more data to be - * parsed. - * - * The function returns simdjson::SUCCESS (as integer = 0) in case of - * success - * and indicates that the buffer has successfully been parsed to the end. - * Every document it contained has been parsed without error. - * - * The function returns an error code from simdjson/simdjson.h in case of - * failure - * such as simdjson::CAPACITY, simdjson::MEMALLOC, simdjson::DEPTH_ERROR and - * so forth; - * the simdjson::error_message function converts these error codes into a - * string). - * - * You can also check validity by calling parser.is_valid(). The same parser - * can - * and should be reused for the other documents in the buffer. - */ - inline void next() noexcept; - - /** Move the json_iterator of the document to the location of the next - * document in the stream. 
*/ - inline void next_document() noexcept; - - /** Get the next document index. */ - inline size_t next_batch_start() const noexcept; - - /** Pass the next batch through stage 1 with the given parser. */ - inline error_code run_stage1(ondemand::parser &p, - size_t batch_start) noexcept; - - // Fields - ondemand::parser *parser; - const uint8_t *buf; - size_t len; - size_t batch_size; - /** - * We are going to use just one document instance. The document owns - * the json_iterator. It implies that we only ever pass a reference - * to the document to the users. - */ - document doc{}; - /** The error (or lack thereof) from the current document. */ - error_code error; - size_t batch_start{0}; - size_t doc_index{}; - -#ifdef SIMDJSON_THREADS_ENABLED - /** Indicates whether we use threads. Note that this needs to be a constant - * during the execution of the parsing. */ - bool use_thread; - - inline void load_from_stage1_thread() noexcept; - - /** Start a thread to run stage 1 on the next batch. */ - inline void start_stage1_thread() noexcept; - - /** Wait for the stage 1 thread to finish and capture the results. */ - inline void finish_stage1_thread() noexcept; - - /** The error returned from the stage 1 thread. */ - error_code stage1_thread_error{UNINITIALIZED}; - /** The thread used to run stage 1 against the next batch in the background. - */ - std::unique_ptr worker{new (std::nothrow) stage1_worker()}; - /** - * The parser used to run stage 1 in the background. Will be swapped - * with the regular parser when finished. 
- */ - ondemand::parser stage1_thread_parser{}; - - friend struct stage1_worker; -#endif // SIMDJSON_THREADS_ENABLED - - friend class parser; - friend class document; - friend class json_iterator; - friend struct simdjson_result; - friend struct internal::simdjson_result_base; -}; // document_stream - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { -template <> -struct simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream> - : public SIMDJSON_BUILTIN_IMPLEMENTATION:: - implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream> { - public: - simdjson_really_inline simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream - &&value) noexcept; ///< @private - simdjson_really_inline simdjson_result( - error_code error) noexcept; ///< @private - simdjson_really_inline simdjson_result() noexcept = default; -}; - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/document_stream.h */ -/* begin file include/simdjson/generic/ondemand/serialization.h */ - -namespace simdjson { -/** - * Create a string-view instance out of a document instance. The string-view - * instance - * contains JSON text that is suitable to be parsed as JSON again. - */ -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &x) noexcept; -/** - * Create a string-view instance out of a value instance. The string-view - * instance - * contains JSON text that is suitable to be parsed as JSON again. The value - * must - * not have been accessed previously. - */ -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value &x) noexcept; -/** - * Create a string-view instance out of an object instance. The string-view - * instance - * contains JSON text that is suitable to be parsed as JSON again. 
- */ -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object &x) noexcept; -/** - * Create a string-view instance out of an array instance. The string-view - * instance - * contains JSON text that is suitable to be parsed as JSON again. - */ -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array &x) noexcept; -inline simdjson_result to_json_string( - simdjson_result x); -inline simdjson_result to_json_string( - simdjson_result x); -inline simdjson_result to_json_string( - simdjson_result x); -inline simdjson_result to_json_string( - simdjson_result x); -} // namespace simdjson - -/** - * We want to support argument-dependent lookup (ADL). - * Hence we should define operator<< in the namespace - * where the argument (here value, object, etc.) resides. - * Credit: @madhur4127 - * See https://github.com/simdjson/simdjson/issues/1768 - */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The element. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. - */ -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value x); -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> x); -#endif -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The array. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. 
- */ -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array value); -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array> x); -#endif -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The array. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. - */ -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &value); -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document> &&x); -#endif -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference - &value); -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference> - &&x); -#endif -/** - * Print JSON to an output stream. - * - * @param out The output stream. - * @param value The object. - * @throw if there is an error with the underlying output stream. simdjson - * itself will not throw. 
- */ -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object value); -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object> x); -#endif -} -} -} // namespace simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand - /* end file include/simdjson/generic/ondemand/serialization.h */ - /* end file include/simdjson/generic/ondemand.h */ - -// Inline definitions -/* begin file include/simdjson/generic/implementation_simdjson_result_base-inl.h - */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { - -// -// internal::implementation_simdjson_result_base inline implementation -// - -template - simdjson_really_inline void implementation_simdjson_result_base::tie( - T &value, error_code &error) && - noexcept { - error = this->second; - if (!error) { - value = - std::forward>(*this).first; - } -} - -template - simdjson_warn_unused simdjson_really_inline error_code - implementation_simdjson_result_base::get(T &value) && - noexcept { - error_code error; - std::forward>(*this).tie(value, - error); - return error; -} - -template -simdjson_really_inline error_code -implementation_simdjson_result_base::error() const noexcept { - return this->second; -} - -#if SIMDJSON_EXCEPTIONS - -template - simdjson_really_inline T &implementation_simdjson_result_base::value() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return this->first; -} - -template - simdjson_really_inline T && - implementation_simdjson_result_base::value() && - noexcept(false) { - return std::forward>(*this) - .take_value(); -} - -template - simdjson_really_inline T && - implementation_simdjson_result_base::take_value() && - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return std::forward(this->first); -} - -template - simdjson_really_inline 
implementation_simdjson_result_base:: - operator T &&() && - noexcept(false) { - return std::forward>(*this) - .take_value(); -} - -#endif // SIMDJSON_EXCEPTIONS - -template -simdjson_really_inline const T & -implementation_simdjson_result_base::value_unsafe() const &noexcept { - return this->first; -} - -template - simdjson_really_inline T & - implementation_simdjson_result_base::value_unsafe() & - noexcept { - return this->first; -} - -template - simdjson_really_inline T && - implementation_simdjson_result_base::value_unsafe() && - noexcept { - return std::forward(this->first); -} - -template -simdjson_really_inline -implementation_simdjson_result_base::implementation_simdjson_result_base( - T &&value, error_code error) noexcept : first{std::forward(value)}, - second{error} {} -template -simdjson_really_inline implementation_simdjson_result_base< - T>::implementation_simdjson_result_base(error_code error) noexcept - : implementation_simdjson_result_base(T{}, error) {} -template -simdjson_really_inline implementation_simdjson_result_base< - T>::implementation_simdjson_result_base(T &&value) noexcept - : implementation_simdjson_result_base(std::forward(value), SUCCESS) {} - -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson -/* end file include/simdjson/generic/implementation_simdjson_result_base-inl.h - */ -/* begin file include/simdjson/generic/ondemand-inl.h */ -/* begin file include/simdjson/generic/ondemand/json_type-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -inline std::ostream &operator<<(std::ostream &out, json_type type) noexcept { - switch (type) { - case json_type::array: - out << "array"; - break; - case json_type::object: - out << "object"; - break; - case json_type::number: - out << "number"; - break; - case json_type::string: - out << "string"; - break; - case json_type::boolean: - out << "boolean"; - break; - case json_type::null: - out << "null"; - break; - default: - 
SIMDJSON_UNREACHABLE(); - } - return out; -} - -inline std::ostream &operator<<(std::ostream &out, number_type type) noexcept { - switch (type) { - case number_type::signed_integer: - out << "integer in [-9223372036854775808,9223372036854775808)"; - break; - case number_type::unsigned_integer: - out << "unsigned integer in " - "[9223372036854775808,18446744073709551616)"; - break; - case number_type::floating_point_number: - out << "floating-point number (binary64)"; - break; - default: - SIMDJSON_UNREACHABLE(); - } - return out; -} -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, simdjson_result &type) noexcept(false) { - return out << type.value(); -} -#endif - - -simdjson_really_inline number_type number::get_number_type() const noexcept { - return type; -} - -simdjson_really_inline bool number::is_uint64() const noexcept { - return get_number_type() == number_type::unsigned_integer; -} - -simdjson_really_inline uint64_t number::get_uint64() const noexcept { - return payload.unsigned_integer; -} - -simdjson_really_inline number::operator uint64_t() const noexcept { - return get_uint64(); -} - - -simdjson_really_inline bool number::is_int64() const noexcept { - return get_number_type() == number_type::signed_integer; -} - -simdjson_really_inline int64_t number::get_int64() const noexcept { - return payload.signed_integer; -} - -simdjson_really_inline number::operator int64_t() const noexcept { - return get_int64(); -} - -simdjson_really_inline bool number::is_double() const noexcept { - return get_number_type() == number_type::floating_point_number; -} - -simdjson_really_inline double number::get_double() const noexcept { - return payload.floating_point_number; -} - -simdjson_really_inline number::operator double() const noexcept { - return get_double(); -} - -simdjson_really_inline double number::as_double() const noexcept { - if (is_double()) { - return payload.floating_point_number; - } - if (is_int64()) { - return 
double(payload.signed_integer); - } - return double(payload.unsigned_integer); -} - -simdjson_really_inline void number::append_s64(int64_t value) noexcept { - payload.signed_integer = value; - type = number_type::signed_integer; -} - -simdjson_really_inline void number::append_u64(uint64_t value) noexcept { - payload.unsigned_integer = value; - type = number_type::unsigned_integer; -} - -simdjson_really_inline void number::append_double(double value) noexcept { - payload.floating_point_number = value; - type = number_type::floating_point_number; -} - -simdjson_really_inline void number::skip_double() noexcept { - type = number_type::floating_point_number; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type>(error) {} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/json_type-inl.h */ -/* begin file include/simdjson/generic/ondemand/logger-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { -namespace logger { - -static constexpr const char *DASHES = - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "--------------------------------------------------------------------------" - "----------------------------------"; -static constexpr const int LOG_EVENT_LEN = 20; -static constexpr const int LOG_BUFFER_LEN = 30; -static constexpr const int LOG_SMALL_BUFFER_LEN = 10; -static int 
log_depth = 0; // Not threadsafe. Log only. - -// Helper to turn unprintable or newline characters into spaces -static inline char printable_char(char c) { - if (c >= 0x20) { - return c; - } else { - return ' '; - } -} - -inline void log_event(const json_iterator &iter, - const char *type, - std::string_view detail, - int delta, - int depth_delta) noexcept { - log_line(iter, "", type, detail, delta, depth_delta); -} - -inline void log_value(const json_iterator &iter, - token_position index, - depth_t depth, - const char *type, - std::string_view detail) noexcept { - log_line(iter, index, depth, "", type, detail); -} -inline void log_value(const json_iterator &iter, - const char *type, - std::string_view detail, - int delta, - int depth_delta) noexcept { - log_line(iter, "", type, detail, delta, depth_delta); -} - -inline void log_start_value(const json_iterator &iter, - token_position index, - depth_t depth, - const char *type, - std::string_view detail) noexcept { - log_line(iter, index, depth, "+", type, detail); - if (LOG_ENABLED) { - log_depth++; - } -} -inline void log_start_value(const json_iterator &iter, - const char *type, - int delta, - int depth_delta) noexcept { - log_line(iter, "+", type, "", delta, depth_delta); - if (LOG_ENABLED) { - log_depth++; - } -} - -inline void log_end_value(const json_iterator &iter, - const char *type, - int delta, - int depth_delta) noexcept { - if (LOG_ENABLED) { - log_depth--; - } - log_line(iter, "-", type, "", delta, depth_delta); -} - -inline void log_error(const json_iterator &iter, - const char *error, - const char *detail, - int delta, - int depth_delta) noexcept { - log_line(iter, "ERROR: ", error, detail, delta, depth_delta); -} -inline void log_error(const json_iterator &iter, - token_position index, - depth_t depth, - const char *error, - const char *detail) noexcept { - log_line(iter, index, depth, "ERROR: ", error, detail); -} - -inline void log_event(const value_iterator &iter, - const char *type, - 
std::string_view detail, - int delta, - int depth_delta) noexcept { - log_event(iter.json_iter(), type, detail, delta, depth_delta); -} - -inline void log_value(const value_iterator &iter, - const char *type, - std::string_view detail, - int delta, - int depth_delta) noexcept { - log_value(iter.json_iter(), type, detail, delta, depth_delta); -} - -inline void log_start_value(const value_iterator &iter, - const char *type, - int delta, - int depth_delta) noexcept { - log_start_value(iter.json_iter(), type, delta, depth_delta); -} - -inline void log_end_value(const value_iterator &iter, - const char *type, - int delta, - int depth_delta) noexcept { - log_end_value(iter.json_iter(), type, delta, depth_delta); -} - -inline void log_error(const value_iterator &iter, - const char *error, - const char *detail, - int delta, - int depth_delta) noexcept { - log_error(iter.json_iter(), error, detail, delta, depth_delta); -} - -inline void log_headers() noexcept { - if (LOG_ENABLED) { - // Technically a static variable is not thread-safe, but if you are - // using threads - // and logging... well... - static bool displayed_hint{false}; - log_depth = 0; - printf("\n"); - if (!displayed_hint) { - // We only print this helpful header once. - printf( - "# Logging provides the depth and position of the iterator " - "user-visible steps:\n"); - printf( - "# +array says 'this is where we were when we discovered the " - "start array'\n"); - printf( - "# -array says 'this is where we were when we ended the " - "array'\n"); - printf( - "# skip says 'this is a structural or value I am skipping'\n"); - printf( - "# +/-skip says 'this is a start/end array or object I am " - "skipping'\n"); - printf("#\n"); - printf( - "# The identation of the terms (array, string,...) 
indicates " - "the depth,\n"); - printf("# in addition to the depth being displayed.\n"); - printf("#\n"); - printf( - "# Every token in the document has a single depth determined " - "by the tokens before it,\n"); - printf("# and is not affected by what the token actually is.\n"); - printf("#\n"); - printf( - "# Not all structural elements are presented as tokens in the " - "logs.\n"); - printf("#\n"); - printf( - "# We never give control to the user within an empty array or " - "an empty object.\n"); - printf("#\n"); - printf( - "# Inside an array, having a depth greater than the array's " - "depth means that\n"); - printf("# we are pointing inside a value.\n"); - printf( - "# Having a depth equal to the array means that we are " - "pointing right before a value.\n"); - printf( - "# Having a depth smaller than the array means that we have " - "moved beyond the array.\n"); - displayed_hint = true; - } - printf("\n"); - printf("| %-*s ", LOG_EVENT_LEN, "Event"); - printf("| %-*s ", LOG_BUFFER_LEN, "Buffer"); - printf("| %-*s ", LOG_SMALL_BUFFER_LEN, "Next"); - // printf("| %-*s ", 5, "Next#"); - printf("| %-*s ", 5, "Depth"); - printf("| Detail "); - printf("|\n"); - - printf("|%.*s", LOG_EVENT_LEN + 2, DASHES); - printf("|%.*s", LOG_BUFFER_LEN + 2, DASHES); - printf("|%.*s", LOG_SMALL_BUFFER_LEN + 2, DASHES); - // printf("|%.*s", 5+2, DASHES); - printf("|%.*s", 5 + 2, DASHES); - printf("|--------"); - printf("|\n"); - fflush(stdout); - } -} - -inline void log_line(const json_iterator &iter, - const char *title_prefix, - const char *title, - std::string_view detail, - int delta, - int depth_delta) noexcept { - log_line(iter, - iter.position() + delta, - depth_t(iter.depth() + depth_delta), - title_prefix, - title, - detail); -} -inline void log_line(const json_iterator &iter, - token_position index, - depth_t depth, - const char *title_prefix, - const char *title, - std::string_view detail) noexcept { - if (LOG_ENABLED) { - const int indent = depth * 2; - const 
auto buf = iter.token.buf; - printf("| %*s%s%-*s ", - indent, - "", - title_prefix, - LOG_EVENT_LEN - indent - int(strlen(title_prefix)), - title); - { - // Print the current structural. - printf("| "); - auto current_structural = &buf[*index]; - for (int i = 0; i < LOG_BUFFER_LEN; i++) { - printf("%c", printable_char(current_structural[i])); - } - printf(" "); - } - { - // Print the next structural. - printf("| "); - auto next_structural = &buf[*(index + 1)]; - for (int i = 0; i < LOG_SMALL_BUFFER_LEN; i++) { - printf("%c", printable_char(next_structural[i])); - } - printf(" "); - } - // printf("| %5u ", *(index+1)); - printf("| %5u ", depth); - printf("| %.*s ", int(detail.size()), detail.data()); - printf("|\n"); - fflush(stdout); - } -} - -} // namespace logger -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/logger-inl.h */ -/* begin file include/simdjson/generic/ondemand/raw_json_string-inl.h */ -namespace simdjson { - -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline raw_json_string::raw_json_string( - const uint8_t *_buf) noexcept : buf{_buf} {} - -simdjson_really_inline const char *raw_json_string::raw() const noexcept { - return reinterpret_cast(buf); -} -simdjson_really_inline simdjson_warn_unused simdjson_result -raw_json_string::unescape(uint8_t *&dst) const noexcept { - uint8_t *end = stringparsing::parse_string(buf, dst); - if (!end) { - return STRING_ERROR; - } - std::string_view result(reinterpret_cast(dst), end - dst); - dst = end; - return result; -} - -simdjson_really_inline bool raw_json_string::is_free_from_unescaped_quote( - std::string_view target) noexcept { - size_t pos{0}; - // if the content has no escape character, just scan through it quickly! - for (; pos < target.size() && target[pos] != '\\'; pos++) { - } - // slow path may begin. 
- bool escaping{false}; - for (; pos < target.size(); pos++) { - if ((target[pos] == '"') && !escaping) { - return false; - } else if (target[pos] == '\\') { - escaping = !escaping; - } else { - escaping = false; - } - } - return true; -} - -simdjson_really_inline bool raw_json_string::is_free_from_unescaped_quote( - const char *target) noexcept { - size_t pos{0}; - // if the content has no escape character, just scan through it quickly! - for (; target[pos] && target[pos] != '\\'; pos++) { - } - // slow path may begin. - bool escaping{false}; - for (; target[pos]; pos++) { - if ((target[pos] == '"') && !escaping) { - return false; - } else if (target[pos] == '\\') { - escaping = !escaping; - } else { - escaping = false; - } - } - return true; -} - - -simdjson_really_inline bool raw_json_string::unsafe_is_equal( - size_t length, std::string_view target) const noexcept { - // If we are going to call memcmp, then we must know something about the - // length of the raw_json_string. - return (length >= target.size()) && (raw()[target.size()] == '"') && - !memcmp(raw(), target.data(), target.size()); -} - -simdjson_really_inline bool raw_json_string::unsafe_is_equal( - std::string_view target) const noexcept { - // Assumptions: does not contain unescaped quote characters, and - // the raw content is quote terminated within a valid JSON string. 
- if (target.size() <= SIMDJSON_PADDING) { - return (raw()[target.size()] == '"') && - !memcmp(raw(), target.data(), target.size()); - } - const char *r{raw()}; - size_t pos{0}; - for (; pos < target.size(); pos++) { - if (r[pos] != target[pos]) { - return false; - } - } - if (r[pos] != '"') { - return false; - } - return true; -} - -simdjson_really_inline bool raw_json_string::is_equal( - std::string_view target) const noexcept { - const char *r{raw()}; - size_t pos{0}; - bool escaping{false}; - for (; pos < target.size(); pos++) { - if (r[pos] != target[pos]) { - return false; - } - // if target is a compile-time constant and it is free from - // quotes, then the next part could get optimized away through - // inlining. - if ((target[pos] == '"') && !escaping) { - // We have reached the end of the raw_json_string but - // the target is not done. - return false; - } else if (target[pos] == '\\') { - escaping = !escaping; - } else { - escaping = false; - } - } - if (r[pos] != '"') { - return false; - } - return true; -} - - -simdjson_really_inline bool raw_json_string::unsafe_is_equal( - const char *target) const noexcept { - // Assumptions: 'target' does not contain unescaped quote characters, is - // null terminated and - // the raw content is quote terminated within a valid JSON string. - const char *r{raw()}; - size_t pos{0}; - for (; target[pos]; pos++) { - if (r[pos] != target[pos]) { - return false; - } - } - if (r[pos] != '"') { - return false; - } - return true; -} - -simdjson_really_inline bool raw_json_string::is_equal(const char *target) const - noexcept { - // Assumptions: does not contain unescaped quote characters, and - // the raw content is quote terminated within a valid JSON string. 
- const char *r{raw()}; - size_t pos{0}; - bool escaping{false}; - for (; target[pos]; pos++) { - if (r[pos] != target[pos]) { - return false; - } - // if target is a compile-time constant and it is free from - // quotes, then the next part could get optimized away through - // inlining. - if ((target[pos] == '"') && !escaping) { - // We have reached the end of the raw_json_string but - // the target is not done. - return false; - } else if (target[pos] == '\\') { - escaping = !escaping; - } else { - escaping = false; - } - } - if (r[pos] != '"') { - return false; - } - return true; -} - -simdjson_unused simdjson_really_inline bool operator==( - const raw_json_string &a, std::string_view c) noexcept { - return a.unsafe_is_equal(c); -} - -simdjson_unused simdjson_really_inline bool operator==( - std::string_view c, const raw_json_string &a) noexcept { - return a == c; -} - -simdjson_unused simdjson_really_inline bool operator!=( - const raw_json_string &a, std::string_view c) noexcept { - return !(a == c); -} - -simdjson_unused simdjson_really_inline bool operator!=( - std::string_view c, const raw_json_string &a) noexcept { - return !(a == c); -} - - -simdjson_really_inline simdjson_warn_unused simdjson_result -raw_json_string::unescape(json_iterator &iter) const noexcept { - return unescape(iter.string_buf_loc()); -} - - -simdjson_unused simdjson_really_inline std::ostream &operator<<( - std::ostream &out, const raw_json_string &str) noexcept { - bool in_escape = false; - const char *s = str.raw(); - while (true) { - switch (*s) { - case '\\': - in_escape = !in_escape; - break; - case '"': - if (in_escape) { - in_escape = false; - } else { - return out; - } - break; - default: - if (in_escape) { - in_escape = false; - } - } - out << *s; - s++; - } -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - 
simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string>( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string>(error) {} - -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string>::raw() const - noexcept { - if (error()) { - return error(); - } - return first.raw(); -} -simdjson_really_inline simdjson_warn_unused simdjson_result -simdjson_result:: - unescape(uint8_t *&dst) const noexcept { - if (error()) { - return error(); - } - return first.unescape(dst); -} -simdjson_really_inline simdjson_warn_unused simdjson_result -simdjson_result:: - unescape(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator &iter) - const noexcept { - if (error()) { - return error(); - } - return first.unescape(iter); -} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/raw_json_string-inl.h */ -/* begin file include/simdjson/generic/ondemand/token_iterator-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline token_iterator::token_iterator( - const uint8_t *_buf, token_position position) noexcept - : buf{_buf}, - _position{position} {} - -simdjson_really_inline uint32_t token_iterator::current_offset() const - noexcept { - return *(_position); -} - - -simdjson_really_inline const uint8_t * -token_iterator::return_current_and_advance() noexcept { - return &buf[*(_position++)]; -} - -simdjson_really_inline const uint8_t *token_iterator::peek( - token_position position) const noexcept { - return &buf[*position]; -} -simdjson_really_inline uint32_t 
-token_iterator::peek_index(token_position position) const noexcept { - return *position; -} -simdjson_really_inline uint32_t -token_iterator::peek_length(token_position position) const noexcept { - return *(position + 1) - *position; -} - -simdjson_really_inline const uint8_t *token_iterator::peek(int32_t delta) const - noexcept { - return &buf[*(_position + delta)]; -} -simdjson_really_inline uint32_t token_iterator::peek_index(int32_t delta) const - noexcept { - return *(_position + delta); -} -simdjson_really_inline uint32_t token_iterator::peek_length(int32_t delta) const - noexcept { - return *(_position + delta + 1) - *(_position + delta); -} - -simdjson_really_inline token_position token_iterator::position() const - noexcept { - return _position; -} -simdjson_really_inline void token_iterator::set_position( - token_position target_position) noexcept { - _position = target_position; -} - -simdjson_really_inline bool token_iterator::operator==( - const token_iterator &other) const noexcept { - return _position == other._position; -} -simdjson_really_inline bool token_iterator::operator!=( - const token_iterator &other) const noexcept { - return _position != other._position; -} -simdjson_really_inline bool token_iterator::operator>( - const token_iterator &other) const noexcept { - return _position > other._position; -} -simdjson_really_inline bool token_iterator::operator>=( - const token_iterator &other) const noexcept { - return _position >= other._position; -} -simdjson_really_inline bool token_iterator::operator<( - const token_iterator &other) const noexcept { - return _position < other._position; -} -simdjson_really_inline bool token_iterator::operator<=( - const token_iterator &other) const noexcept { - return _position <= other._position; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - 
simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator>( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::token_iterator>(error) {} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/token_iterator-inl.h */ -/* begin file include/simdjson/generic/ondemand/json_iterator-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline json_iterator::json_iterator( - json_iterator &&other) noexcept - : token(std::forward(other.token)), - parser{other.parser}, - _string_buf_loc{other._string_buf_loc}, - error{other.error}, - _depth{other._depth}, - _root{other._root}, - _streaming{other._streaming} { - other.parser = nullptr; -} -simdjson_really_inline json_iterator &json_iterator::operator=( - json_iterator &&other) noexcept { - token = other.token; - parser = other.parser; - _string_buf_loc = other._string_buf_loc; - error = other.error; - _depth = other._depth; - _root = other._root; - _streaming = other._streaming; - other.parser = nullptr; - return *this; -} - -simdjson_really_inline json_iterator::json_iterator( - const uint8_t *buf, ondemand::parser *_parser) noexcept - : token(buf, &_parser->implementation->structural_indexes[0]), - parser{_parser}, - _string_buf_loc{parser->string_buf.get()}, - _depth{1}, - _root{parser->implementation->structural_indexes.get()}, - _streaming{false} - -{ - logger::log_headers(); -#if SIMDJSON_CHECK_EOF - assert_more_tokens(); -#endif -} - -inline void json_iterator::rewind() noexcept { - token.set_position(root_position()); - logger::log_headers(); // We start again - _string_buf_loc = 
parser->string_buf.get(); - _depth = 1; -} - -// GCC 7 warns when the first line of this function is inlined away into -// oblivion due to the caller -// relating depth and parent_depth, which is a desired effect. The warning does -// not show up if the -// skip_child() function is not marked inline). -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_STRICT_OVERFLOW_WARNING -simdjson_warn_unused simdjson_really_inline error_code -json_iterator::skip_child(depth_t parent_depth) noexcept { - if (depth() <= parent_depth) { - return SUCCESS; - } - switch (*return_current_and_advance()) { - // TODO consider whether matching braces is a requirement: if - // non-matching braces indicates - // *missing* braces, then future lookups are not in the object/arrays - // they think they are, - // violating the rule "validate enough structure that the user can be - // confident they are - // looking at the right values." - // PERF TODO we can eliminate the switch here with a lookup of how much - // to add to depth - - // For the first open array/object in a value, we've already incremented - // depth, so keep it the same - // We never stop at colon, but if we did, it wouldn't affect depth - case '[': - case '{': - case ':': - logger::log_start_value(*this, "skip"); - break; - // If there is a comma, we have just finished a value in an - // array/object, and need to get back in - case ',': - logger::log_value(*this, "skip"); - break; - // ] or } means we just finished a value and need to jump out of the - // array/object - case ']': - case '}': - logger::log_end_value(*this, "skip"); - _depth--; - if (depth() <= parent_depth) { - return SUCCESS; - } -#if SIMDJSON_CHECK_EOF - // If there are no more tokens, the parent is incomplete. - if (at_end()) { - return report_error(INCOMPLETE_ARRAY_OR_OBJECT, - "Missing [ or { at start"); - } -#endif // SIMDJSON_CHECK_EOF - break; - case '"': - if (*peek() == ':') { - // We are at a key!!! 
- // This might happen if you just started an object and you skip - // it immediately. - // Performance note: it would be nice to get rid of this check - // as it is somewhat - // expensive. - // https://github.com/simdjson/simdjson/issues/1742 - logger::log_value(*this, "key"); - return_current_and_advance(); // eat up the ':' - break; // important!!! - } - simdjson_fallthrough; - // Anything else must be a scalar value - default: - // For the first scalar, we will have incremented depth already, so - // we decrement it here. - logger::log_value(*this, "skip"); - _depth--; - if (depth() <= parent_depth) { - return SUCCESS; - } - break; - } - - // Now that we've considered the first value, we only increment/decrement - // for arrays/objects - while (position() < end_position()) { - switch (*return_current_and_advance()) { - case '[': - case '{': - logger::log_start_value(*this, "skip"); - _depth++; - break; - // TODO consider whether matching braces is a requirement: if - // non-matching braces indicates - // *missing* braces, then future lookups are not in the - // object/arrays they think they are, - // violating the rule "validate enough structure that the user can - // be confident they are - // looking at the right values." 
- // PERF TODO we can eliminate the switch here with a lookup of how - // much to add to depth - case ']': - case '}': - logger::log_end_value(*this, "skip"); - _depth--; - if (depth() <= parent_depth) { - return SUCCESS; - } - break; - default: - logger::log_value(*this, "skip", ""); - break; - } - } - - return report_error(TAPE_ERROR, "not enough close braces"); -} - -SIMDJSON_POP_DISABLE_WARNINGS - -simdjson_really_inline bool json_iterator::at_root() const noexcept { - return position() == root_position(); -} - -simdjson_really_inline bool json_iterator::streaming() const noexcept { - return _streaming; -} - -simdjson_really_inline token_position json_iterator::root_position() const - noexcept { - return _root; -} - -simdjson_really_inline void json_iterator::assert_at_document_depth() const - noexcept { - SIMDJSON_ASSUME(_depth == 1); -} - -simdjson_really_inline void json_iterator::assert_at_root() const noexcept { - SIMDJSON_ASSUME(_depth == 1); -#ifndef SIMDJSON_CLANG_VISUAL_STUDIO - // Under Visual Studio, the next SIMDJSON_ASSUME fails with: the argument - // has side effects that will be discarded. 
- SIMDJSON_ASSUME(token.position() == _root); -#endif -} - -simdjson_really_inline void json_iterator::assert_more_tokens( - uint32_t required_tokens) const noexcept { - assert_valid_position(token._position + required_tokens - 1); -} - -simdjson_really_inline void json_iterator::assert_valid_position( - token_position position) const noexcept { -#ifndef SIMDJSON_CLANG_VISUAL_STUDIO - SIMDJSON_ASSUME(position >= &parser->implementation->structural_indexes[0]); - SIMDJSON_ASSUME(position < - &parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes]); -#endif -} - -simdjson_really_inline bool json_iterator::at_end() const noexcept { - return position() == end_position(); -} -simdjson_really_inline token_position json_iterator::end_position() const - noexcept { - uint32_t n_structural_indexes{parser->implementation->n_structural_indexes}; - return &parser->implementation->structural_indexes[n_structural_indexes]; -} - -inline std::string json_iterator::to_string() const noexcept { - if (!is_alive()) { - return "dead json_iterator instance"; - } - const char *current_structural = - reinterpret_cast(token.peek()); - return std::string("json_iterator [ depth : ") + std::to_string(_depth) + - std::string(", structural : '") + - std::string(current_structural, 1) + std::string("', offset : ") + - std::to_string(token.current_offset()) + std::string("', error : ") + - error_message(error) + std::string(" ]"); -} - -inline simdjson_result -json_iterator::current_location() noexcept { - if (!is_alive()) { // Unrecoverable error - if (!at_root()) { - return reinterpret_cast(token.peek(-1)); - } else { - return reinterpret_cast(token.peek()); - } - } - if (at_end()) { - return OUT_OF_BOUNDS; - } - return reinterpret_cast(token.peek()); -} - -simdjson_really_inline bool json_iterator::is_alive() const noexcept { - return parser; -} - -simdjson_really_inline void json_iterator::abandon() noexcept { - parser = nullptr; - _depth = 0; -} - 
-simdjson_really_inline const uint8_t * -json_iterator::return_current_and_advance() noexcept { -#if SIMDJSON_CHECK_EOF - assert_more_tokens(); -#endif // SIMDJSON_CHECK_EOF - return token.return_current_and_advance(); -} - -simdjson_really_inline const uint8_t *json_iterator::unsafe_pointer() const - noexcept { - // deliberately done without safety guard: - return token.peek(0); -} - -simdjson_really_inline const uint8_t *json_iterator::peek(int32_t delta) const - noexcept { -#if SIMDJSON_CHECK_EOF - assert_more_tokens(delta + 1); -#endif // SIMDJSON_CHECK_EOF - return token.peek(delta); -} - -simdjson_really_inline uint32_t json_iterator::peek_length(int32_t delta) const - noexcept { -#if SIMDJSON_CHECK_EOF - assert_more_tokens(delta + 1); -#endif // #if SIMDJSON_CHECK_EOF - return token.peek_length(delta); -} - -simdjson_really_inline const uint8_t *json_iterator::peek( - token_position position) const noexcept { - // todo: currently we require end-of-string buffering, but the following - // assert_valid_position should be turned on if/when we lift that condition. - // assert_valid_position(position); - // This is almost surely related to SIMDJSON_CHECK_EOF but given that - // SIMDJSON_CHECK_EOF - // is ON by default, we have no choice but to disable it for real with a - // comment. - return token.peek(position); -} - -simdjson_really_inline uint32_t -json_iterator::peek_length(token_position position) const noexcept { -#if SIMDJSON_CHECK_EOF - assert_valid_position(position); -#endif // SIMDJSON_CHECK_EOF - return token.peek_length(position); -} - -simdjson_really_inline token_position json_iterator::last_position() const - noexcept { - // The following line fails under some compilers... - // SIMDJSON_ASSUME(parser->implementation->n_structural_indexes > 0); - // since it has side-effects. 
- uint32_t n_structural_indexes{parser->implementation->n_structural_indexes}; - SIMDJSON_ASSUME(n_structural_indexes > 0); - return &parser->implementation - ->structural_indexes[n_structural_indexes - 1]; -} -simdjson_really_inline const uint8_t *json_iterator::peek_last() const - noexcept { - return token.peek(last_position()); -} - -simdjson_really_inline void json_iterator::ascend_to( - depth_t parent_depth) noexcept { - SIMDJSON_ASSUME(parent_depth >= 0 && parent_depth < INT32_MAX - 1); - SIMDJSON_ASSUME(_depth == parent_depth + 1); - _depth = parent_depth; -} - -simdjson_really_inline void json_iterator::descend_to( - depth_t child_depth) noexcept { - SIMDJSON_ASSUME(child_depth >= 1 && child_depth < INT32_MAX); - SIMDJSON_ASSUME(_depth == child_depth - 1); - _depth = child_depth; -} - -simdjson_really_inline depth_t json_iterator::depth() const noexcept { - return _depth; -} - -simdjson_really_inline uint8_t *&json_iterator::string_buf_loc() noexcept { - return _string_buf_loc; -} - -simdjson_really_inline error_code -json_iterator::report_error(error_code _error, const char *message) noexcept { - SIMDJSON_ASSUME(_error != SUCCESS && _error != UNINITIALIZED && - _error != INCORRECT_TYPE && _error != NO_SUCH_FIELD); - logger::log_error(*this, message); - error = _error; - return error; -} - -simdjson_really_inline token_position json_iterator::position() const noexcept { - return token.position(); -} - -simdjson_really_inline void json_iterator::reenter_child( - token_position position, depth_t child_depth) noexcept { - SIMDJSON_ASSUME(child_depth >= 1 && child_depth < INT32_MAX); - SIMDJSON_ASSUME(_depth == child_depth - 1); -#ifdef SIMDJSON_DEVELOPMENT_CHECKS -#ifndef SIMDJSON_CLANG_VISUAL_STUDIO - SIMDJSON_ASSUME(position >= parser->start_positions[child_depth]); -#endif -#endif - token.set_position(position); - _depth = child_depth; -} - -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - -simdjson_really_inline token_position -json_iterator::start_position(depth_t 
depth) const noexcept { - return parser->start_positions[depth]; -} - -simdjson_really_inline void json_iterator::set_start_position( - depth_t depth, token_position position) noexcept { - parser->start_positions[depth] = position; -} - -#endif - - -simdjson_really_inline error_code -json_iterator::optional_error(error_code _error, const char *message) noexcept { - SIMDJSON_ASSUME(_error == INCORRECT_TYPE || _error == NO_SUCH_FIELD); - logger::log_error(*this, message); - return _error; -} - -template -simdjson_warn_unused simdjson_really_inline bool json_iterator::copy_to_buffer( - const uint8_t *json, uint32_t max_len, uint8_t (&tmpbuf)[N]) noexcept { - // Let us guard against silly cases: - if ((N < max_len) || (N == 0)) { - return false; - } - // Truncate whitespace to fit the buffer. - if (max_len > N - 1) { - // if (jsoncharutils::is_not_structural_or_whitespace(json[N-1])) { - // return false; } - max_len = N - 1; - } - - // Copy to the buffer. - std::memcpy(tmpbuf, json, max_len); - tmpbuf[max_len] = ' '; - return true; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator>( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_iterator>(error) {} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/json_iterator-inl.h */ -/* begin file include/simdjson/generic/ondemand/value_iterator-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline 
value_iterator::value_iterator( - json_iterator *json_iter, - depth_t depth, - token_position start_position) noexcept : _json_iter{json_iter}, - _depth{depth}, - _start_position{start_position} {} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::start_object() noexcept { - SIMDJSON_TRY(start_container('{', "Not an object", "object")); - return started_object(); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::start_root_object() noexcept { - SIMDJSON_TRY(start_container('{', "Not an object", "object")); - return started_root_object(); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::started_object() noexcept { - assert_at_container_start(); -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - _json_iter->set_start_position(_depth, start_position()); -#endif - if (*_json_iter->peek() == '}') { - logger::log_value(*_json_iter, "empty object"); - _json_iter->return_current_and_advance(); - end_container(); - return false; - } - return true; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::started_root_object() noexcept { - // When in streaming mode, we cannot expect peek_last() to be the last - // structural element of the - // current document. It only works in the normal mode where we have indexed - // a single document. - // Note that adding a check for 'streaming' is not expensive since we only - // have at most - // one root element. 
- if (!_json_iter->streaming() && (*_json_iter->peek_last() != '}')) { - _json_iter->abandon(); - return report_error(INCOMPLETE_ARRAY_OR_OBJECT, "missing } at end"); - } - return started_object(); -} - -simdjson_warn_unused simdjson_really_inline error_code -value_iterator::end_container() noexcept { -#if SIMDJSON_CHECK_EOF - if (depth() > 1 && at_end()) { - return report_error(INCOMPLETE_ARRAY_OR_OBJECT, - "missing parent ] or }"); - } -// if (depth() <= 1 && !at_end()) { return -// report_error(INCOMPLETE_ARRAY_OR_OBJECT, "missing [ or { at start"); } -#endif // SIMDJSON_CHECK_EOF - _json_iter->ascend_to(depth() - 1); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::has_next_field() noexcept { - assert_at_next(); - - // It's illegal to call this unless there are more tokens: anything that - // ends in } or ] is - // obligated to verify there are more tokens if they are not the top level. - switch (*_json_iter->return_current_and_advance()) { - case '}': - logger::log_end_value(*_json_iter, "object"); - SIMDJSON_TRY(end_container()); - return false; - case ',': - return true; - default: - return report_error(TAPE_ERROR, - "Missing comma between object fields"); - } -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::find_field_raw(const std::string_view key) noexcept { - error_code error; - bool has_value; - // - // Initially, the object can be in one of a few different places: - // - // 1. The start of the object, at the first field: - // - // ``` - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 2, index 1) - // ``` - if (at_first_field()) { - has_value = true; - - // - // 2. 
When a previous search did not yield a value or the object is - // empty: - // - // ``` - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 0) - // { } - // ^ (depth 0, index 2) - // ``` - // - } else if (!is_open()) { -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - // If we're past the end of the object, we're being iterated out of - // order. - // Note: this isn't perfect detection. It's possible the user is inside - // some other object; if so, - // this object iterator will blithely scan that object for fields. - if (_json_iter->depth() < depth() - 1) { - return OUT_OF_ORDER_ITERATION; - } -#endif - return false; - - // 3. When a previous search found a field or an iterator yielded a - // value: - // - // ``` - // // When a field was not fully consumed (or not even touched at - // all) - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 2) - // // When a field was fully consumed - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // // When the last field was fully consumed - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // ``` - // - } else { - if ((error = skip_child())) { - abandon(); - return error; - } - if ((error = has_next_field().get(has_value))) { - abandon(); - return error; - } -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - if (_json_iter->start_position(_depth) != start_position()) { - return OUT_OF_ORDER_ITERATION; - } -#endif - } - while (has_value) { - // Get the key and colon, stopping at the value. - raw_json_string actual_key; - // size_t max_key_length = _json_iter->peek_length() - 2; // -2 for the - // two quotes - // Note: _json_iter->peek_length() - 2 might overflow if - // _json_iter->peek_length() < 2. - // field_key() advances the pointer and checks that '"' is found - // (corresponding to a key). - // The depth is left unchanged by field_key(). - if ((error = field_key().get(actual_key))) { - abandon(); - return error; - }; - // field_value() will advance and check that we find a ':' separating - // the - // key and the value. 
It will also increment the depth by one. - if ((error = field_value())) { - abandon(); - return error; - } - // If it matches, stop and return - // We could do it this way if we wanted to allow arbitrary - // key content (including escaped quotes). - // if (actual_key.unsafe_is_equal(max_key_length, key)) { - // Instead we do the following which may trigger buffer overruns if the - // user provides an adversarial key (containing a well placed unescaped - // quote - // character and being longer than the number of bytes remaining in the - // JSON - // input). - if (actual_key.unsafe_is_equal(key)) { - logger::log_event(*this, "match", key, -2); - // If we return here, then we return while pointing at the ':' that - // we just checked. - return true; - } - - // No match: skip the value and see if , or } is next - logger::log_event(*this, "no match", key, -2); - // The call to skip_child is meant to skip over the value corresponding - // to the key. - // After skip_child(), we are right before the next comma (',') or the - // final brace ('}'). - SIMDJSON_TRY(skip_child()); // Skip the value entirely - // The has_next_field() advances the pointer and check that either ',' - // or '}' is found. - // It returns true if ',' is found, false otherwise. If anything other - // than ',' or '}' is found, - // then we are in error and we abort. - if ((error = has_next_field().get(has_value))) { - abandon(); - return error; - } - } - - // If the loop ended, we're out of fields to look at. 
- return false; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::find_field_unordered_raw(const std::string_view key) noexcept { - /** - * When find_field_unordered_raw is called, we can either be pointing at the - * first key, pointing outside (at the closing brace) or if a key was - * matched - * we can be either pointing right afterthe ':' right before the value (that - * we need skip), - * or we may have consumed the value and we might be at a comma or at the - * final brace (ready for a call to has_next_field()). - */ - error_code error; - bool has_value; - - // First, we scan from that point to the end. - // If we don't find a match, we may loop back around, and scan from the - // beginning to that point. - token_position search_start = _json_iter->position(); - - // We want to know whether we need to go back to the beginning. - bool at_first = at_first_field(); - /////////////// - // Initially, the object can be in one of a few different places: - // - // 1. At the first key: - // - // ``` - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 2, index 1) - // ``` - // - if (at_first) { - has_value = true; - - // 2. When a previous search did not yield a value or the object is - // empty: - // - // ``` - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 0) - // { } - // ^ (depth 0, index 2) - // ``` - // - } else if (!is_open()) { -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - // If we're past the end of the object, we're being iterated out of - // order. - // Note: this isn't perfect detection. It's possible the user is inside - // some other object; if so, - // this object iterator will blithely scan that object for fields. - if (_json_iter->depth() < depth() - 1) { - return OUT_OF_ORDER_ITERATION; - } -#endif - SIMDJSON_TRY(reset_object().get(has_value)); - at_first = true; - // 3. 
When a previous search found a field or an iterator yielded a - // value: - // - // ``` - // // When a field was not fully consumed (or not even touched at - // all) - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 2) - // // When a field was fully consumed - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // // When the last field was fully consumed - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // ``` - // - } else { - // If someone queried a key but they not did access the value, then we - // are left pointing - // at the ':' and we need to move forward through the value... If the - // value was - // processed then skip_child() does not move the iterator (but may - // adjust the depth). - if ((error = skip_child())) { - abandon(); - return error; - } - search_start = _json_iter->position(); - if ((error = has_next_field().get(has_value))) { - abandon(); - return error; - } -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - if (_json_iter->start_position(_depth) != start_position()) { - return OUT_OF_ORDER_ITERATION; - } -#endif - } - - // After initial processing, we will be in one of two states: - // - // ``` - // // At the beginning of a field - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 1) - // // At the end of the object - // { "a": [ 1, 2 ], "b": [ 3, 4 ] } - // ^ (depth 0) - // ``` - // - // Next, we find a match starting from the current position. - while (has_value) { - SIMDJSON_ASSUME(_json_iter->_depth == - _depth); // We must be at the start of a field - - // Get the key and colon, stopping at the value. - raw_json_string actual_key; - // size_t max_key_length = _json_iter->peek_length() - 2; // -2 for the - // two quotes - // Note: _json_iter->peek_length() - 2 might overflow if - // _json_iter->peek_length() < 2. - // field_key() advances the pointer and checks that '"' is found - // (corresponding to a key). - // The depth is left unchanged by field_key(). 
- if ((error = field_key().get(actual_key))) { - abandon(); - return error; - }; - // field_value() will advance and check that we find a ':' separating - // the - // key and the value. It will also increment the depth by one. - if ((error = field_value())) { - abandon(); - return error; - } - - // If it matches, stop and return - // We could do it this way if we wanted to allow arbitrary - // key content (including escaped quotes). - // if (actual_key.unsafe_is_equal(max_key_length, key)) { - // Instead we do the following which may trigger buffer overruns if the - // user provides an adversarial key (containing a well placed unescaped - // quote - // character and being longer than the number of bytes remaining in the - // JSON - // input). - if (actual_key.unsafe_is_equal(key)) { - logger::log_event(*this, "match", key, -2); - // If we return here, then we return while pointing at the ':' that - // we just checked. - return true; - } - - // No match: skip the value and see if , or } is next - logger::log_event(*this, "no match", key, -2); - // The call to skip_child is meant to skip over the value corresponding - // to the key. - // After skip_child(), we are right before the next comma (',') or the - // final brace ('}'). - SIMDJSON_TRY(skip_child()); - // The has_next_field() advances the pointer and check that either ',' - // or '}' is found. - // It returns true if ',' is found, false otherwise. If anything other - // than ',' or '}' is found, - // then we are in error and we abort. - if ((error = has_next_field().get(has_value))) { - abandon(); - return error; - } - } - // Performance note: it maybe wasteful to rewind to the beginning when there - // might be - // no other query following. Indeed, it would require reskipping the whole - // object. - // Instead, you can just stay where you are. If there is a new query, there - // is always time - // to rewind. 
- if (at_first) { - return false; - } - - // If we reach the end without finding a match, search the rest of the - // fields starting at the - // beginning of the object. - // (We have already run through the object before, so we've already - // validated its structure. We - // don't check errors in this bit.) - SIMDJSON_TRY(reset_object().get(has_value)); - while (true) { - SIMDJSON_ASSUME(has_value); // we should reach search_start before ever - // reaching the end of the object - SIMDJSON_ASSUME(_json_iter->_depth == - _depth); // We must be at the start of a field - - // Get the key and colon, stopping at the value. - raw_json_string actual_key; - // size_t max_key_length = _json_iter->peek_length() - 2; // -2 for the - // two quotes - // Note: _json_iter->peek_length() - 2 might overflow if - // _json_iter->peek_length() < 2. - // field_key() advances the pointer and checks that '"' is found - // (corresponding to a key). - // The depth is left unchanged by field_key(). - error = field_key().get(actual_key); - SIMDJSON_ASSUME(!error); - // field_value() will advance and check that we find a ':' separating - // the - // key and the value. It will also increment the depth by one. - error = field_value(); - SIMDJSON_ASSUME(!error); - - // If it matches, stop and return - // We could do it this way if we wanted to allow arbitrary - // key content (including escaped quotes). - // if (actual_key.unsafe_is_equal(max_key_length, key)) { - // Instead we do the following which may trigger buffer overruns if the - // user provides an adversarial key (containing a well placed unescaped - // quote - // character and being longer than the number of bytes remaining in the - // JSON - // input). - if (actual_key.unsafe_is_equal(key)) { - logger::log_event(*this, "match", key, -2); - // If we return here, then we return while pointing at the ':' that - // we just checked. 
- return true; - } - - // No match: skip the value and see if , or } is next - logger::log_event(*this, "no match", key, -2); - // The call to skip_child is meant to skip over the value corresponding - // to the key. - // After skip_child(), we are right before the next comma (',') or the - // final brace ('}'). - SIMDJSON_TRY(skip_child()); - // If we reached the end of the key-value pair we started from, then we - // know - // that the key is not there so we return false. We are either right - // before - // the next comma or the final brace. - if (_json_iter->position() == search_start) { - return false; - } - // The has_next_field() advances the pointer and check that either ',' - // or '}' is found. - // It returns true if ',' is found, false otherwise. If anything other - // than ',' or '}' is found, - // then we are in error and we abort. - error = has_next_field().get(has_value); - SIMDJSON_ASSUME(!error); - // If we make the mistake of exiting here, then we could be left - // pointing at a key - // in the middle of an object. That's not an allowable state. - } - // If the loop ended, we're out of fields to look at. The program should - // never reach this point. 
- return false; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::field_key() noexcept { - assert_at_next(); - - const uint8_t *key = _json_iter->return_current_and_advance(); - if (*(key++) != '"') { - return report_error(TAPE_ERROR, "Object key is not a string"); - } - return raw_json_string(key); -} - -simdjson_warn_unused simdjson_really_inline error_code -value_iterator::field_value() noexcept { - assert_at_next(); - - if (*_json_iter->return_current_and_advance() != ':') { - return report_error(TAPE_ERROR, "Missing colon in object field"); - } - _json_iter->descend_to(depth() + 1); - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::start_array() noexcept { - SIMDJSON_TRY(start_container('[', "Not an array", "array")); - return started_array(); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::start_root_array() noexcept { - SIMDJSON_TRY(start_container('[', "Not an array", "array")); - return started_root_array(); -} - -inline std::string value_iterator::to_string() const noexcept { - auto answer = std::string("value_iterator [ depth : ") + - std::to_string(_depth) + std::string(", "); - if (_json_iter != nullptr) { - answer += _json_iter->to_string(); - } - answer += std::string(" ]"); - return answer; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::started_array() noexcept { - assert_at_container_start(); - if (*_json_iter->peek() == ']') { - logger::log_value(*_json_iter, "empty array"); - _json_iter->return_current_and_advance(); - SIMDJSON_TRY(end_container()); - return false; - } - _json_iter->descend_to(depth() + 1); -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - _json_iter->set_start_position(_depth, start_position()); -#endif - return true; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::started_root_array() noexcept { - // When in streaming mode, we cannot expect peek_last() 
to be the last - // structural element of the - // current document. It only works in the normal mode where we have indexed - // a single document. - // Note that adding a check for 'streaming' is not expensive since we only - // have at most - // one root element. - if (!_json_iter->streaming() && (*_json_iter->peek_last() != ']')) { - _json_iter->abandon(); - return report_error(INCOMPLETE_ARRAY_OR_OBJECT, "missing ] at end"); - } - return started_array(); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::has_next_element() noexcept { - assert_at_next(); - - logger::log_event(*this, "has_next_element"); - switch (*_json_iter->return_current_and_advance()) { - case ']': - logger::log_end_value(*_json_iter, "array"); - SIMDJSON_TRY(end_container()); - return false; - case ',': - _json_iter->descend_to(depth() + 1); - return true; - default: - return report_error(TAPE_ERROR, - "Missing comma between array elements"); - } -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::parse_bool(const uint8_t *json) const noexcept { - auto not_true = atomparsing::str4ncmp(json, "true"); - auto not_false = atomparsing::str4ncmp(json, "fals") | (json[4] ^ 'e'); - bool error = - (not_true && not_false) || - jsoncharutils::is_not_structural_or_whitespace(json[not_true ? 
5 : 4]); - if (error) { - return incorrect_type_error("Not a boolean"); - } - return simdjson_result(!not_true); -} -simdjson_really_inline bool value_iterator::parse_null( - const uint8_t *json) const noexcept { - return !atomparsing::str4ncmp(json, "null") && - jsoncharutils::is_structural_or_whitespace(json[4]); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_string() noexcept { - return get_raw_json_string().unescape(_json_iter->string_buf_loc()); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_raw_json_string() noexcept { - auto json = peek_scalar("string"); - if (*json != '"') { - return incorrect_type_error("Not a string"); - } - advance_scalar("string"); - return raw_json_string(json + 1); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_uint64() noexcept { - auto result = numberparsing::parse_unsigned(peek_non_root_scalar("uint64")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("uint64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_uint64_in_string() noexcept { - auto result = - numberparsing::parse_unsigned_in_string(peek_non_root_scalar("uint64")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("uint64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_int64() noexcept { - auto result = numberparsing::parse_integer(peek_non_root_scalar("int64")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("int64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_int64_in_string() noexcept { - auto result = - numberparsing::parse_integer_in_string(peek_non_root_scalar("int64")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("int64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result 
-value_iterator::get_double() noexcept { - auto result = numberparsing::parse_double(peek_non_root_scalar("double")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("double"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_double_in_string() noexcept { - auto result = - numberparsing::parse_double_in_string(peek_non_root_scalar("double")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("double"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_bool() noexcept { - auto result = parse_bool(peek_non_root_scalar("bool")); - if (result.error() == SUCCESS) { - advance_non_root_scalar("bool"); - } - return result; -} -simdjson_really_inline bool value_iterator::is_null() noexcept { - auto result = parse_null(peek_non_root_scalar("null")); - if (result) { - advance_non_root_scalar("null"); - } - return result; -} -simdjson_really_inline bool value_iterator::is_negative() noexcept { - return numberparsing::is_negative(peek_non_root_scalar("numbersign")); -} -simdjson_really_inline bool value_iterator::is_root_negative() noexcept { - return numberparsing::is_negative(peek_root_scalar("numbersign")); -} -simdjson_really_inline simdjson_result -value_iterator::is_integer() noexcept { - return numberparsing::is_integer(peek_non_root_scalar("integer")); -} -simdjson_really_inline simdjson_result -value_iterator::get_number_type() noexcept { - return numberparsing::get_number_type(peek_non_root_scalar("integer")); -} -simdjson_really_inline simdjson_result -value_iterator::get_number() noexcept { - number num; - error_code error = - numberparsing::parse_number(peek_non_root_scalar("number"), num); - if (error) { - return error; - } - return num; -} - -simdjson_really_inline simdjson_result -value_iterator::is_root_integer() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("is_root_integer"); - uint8_t - 
tmpbuf[20 + 1]; // <20 digits> is the longest possible unsigned integer - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - return false; // if there are more than 20 characters, it cannot be - // represented as an integer. - } - return numberparsing::is_integer(tmpbuf); -} - -simdjson_really_inline - simdjson_result - value_iterator::get_root_number_type() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("number"); - // Per - // https://www.exploringbinary.com/maximum-number-of-decimal-digits-in-binary-floating-point-numbers/, - // 1074 is the maximum number of significant fractional digits. Add 8 more - // digits for the biggest - // number: -0.e-308. - uint8_t tmpbuf[1074 + 8 + 1]; - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 1082 characters"); - return NUMBER_ERROR; - } - return numberparsing::get_number_type(tmpbuf); -} -simdjson_really_inline simdjson_result -value_iterator::get_root_number() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("number"); - // Per - // https://www.exploringbinary.com/maximum-number-of-decimal-digits-in-binary-floating-point-numbers/, - // 1074 is the maximum number of significant fractional digits. Add 8 more - // digits for the biggest - // number: -0.e-308. 
- uint8_t tmpbuf[1074 + 8 + 1]; - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 1082 characters"); - return NUMBER_ERROR; - } - number num; - error_code error = numberparsing::parse_number(tmpbuf, num); - if (error) { - return error; - } - advance_root_scalar("number"); - return num; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_string() noexcept { - return get_string(); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_raw_json_string() noexcept { - return get_raw_json_string(); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_uint64() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("uint64"); - uint8_t - tmpbuf[20 + 1]; // <20 digits> is the longest possible unsigned integer - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 20 characters"); - return NUMBER_ERROR; - } - auto result = numberparsing::parse_unsigned(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("uint64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_uint64_in_string() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("uint64"); - uint8_t - tmpbuf[20 + 1]; // <20 digits> is the longest possible unsigned integer - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 20 characters"); - return NUMBER_ERROR; - } - auto result = numberparsing::parse_unsigned_in_string(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("uint64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result 
-value_iterator::get_root_int64() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("int64"); - uint8_t tmpbuf[20 + 1]; // -<19 digits> is the longest possible integer - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 20 characters"); - return NUMBER_ERROR; - } - - auto result = numberparsing::parse_integer(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("int64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_int64_in_string() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("int64"); - uint8_t tmpbuf[20 + 1]; // -<19 digits> is the longest possible integer - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 20 characters"); - return NUMBER_ERROR; - } - - auto result = numberparsing::parse_integer_in_string(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("int64"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_double() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("double"); - // Per - // https://www.exploringbinary.com/maximum-number-of-decimal-digits-in-binary-floating-point-numbers/, - // 1074 is the maximum number of significant fractional digits. Add 8 more - // digits for the biggest - // number: -0.e-308. 
- uint8_t tmpbuf[1074 + 8 + 1]; - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 1082 characters"); - return NUMBER_ERROR; - } - auto result = numberparsing::parse_double(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("double"); - } - return result; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_double_in_string() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("double"); - // Per - // https://www.exploringbinary.com/maximum-number-of-decimal-digits-in-binary-floating-point-numbers/, - // 1074 is the maximum number of significant fractional digits. Add 8 more - // digits for the biggest - // number: -0.e-308. - uint8_t tmpbuf[1074 + 8 + 1]; - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - logger::log_error(*_json_iter, - start_position(), - depth(), - "Root number more than 1082 characters"); - return NUMBER_ERROR; - } - auto result = numberparsing::parse_double_in_string(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("double"); - } - return result; -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value_iterator::get_root_bool() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("bool"); - uint8_t tmpbuf[5 + 1]; - if (!_json_iter->copy_to_buffer(json, max_len, tmpbuf)) { - return incorrect_type_error("Not a boolean"); - } - auto result = parse_bool(tmpbuf); - if (result.error() == SUCCESS) { - advance_root_scalar("bool"); - } - return result; -} -simdjson_really_inline bool value_iterator::is_root_null() noexcept { - auto max_len = peek_start_length(); - auto json = peek_root_scalar("null"); - bool result = - (max_len >= 4 && !atomparsing::str4ncmp(json, "null") && - (max_len == 4 || jsoncharutils::is_structural_or_whitespace(json[5]))); - if (result) { - advance_root_scalar("null"); - } - return 
result; -} - -simdjson_warn_unused simdjson_really_inline error_code -value_iterator::skip_child() noexcept { - SIMDJSON_ASSUME(_json_iter->token._position > _start_position); - SIMDJSON_ASSUME(_json_iter->_depth >= _depth); - - return _json_iter->skip_child(depth()); -} - -simdjson_really_inline value_iterator value_iterator::child() const noexcept { - assert_at_child(); - return {_json_iter, depth() + 1, _json_iter->token.position()}; -} - -// GCC 7 warns when the first line of this function is inlined away into -// oblivion due to the caller -// relating depth and iterator depth, which is a desired effect. It does not -// happen if is_open is -// marked non-inline. -SIMDJSON_PUSH_DISABLE_WARNINGS -SIMDJSON_DISABLE_STRICT_OVERFLOW_WARNING -simdjson_really_inline bool value_iterator::is_open() const noexcept { - return _json_iter->depth() >= depth(); -} -SIMDJSON_POP_DISABLE_WARNINGS - -simdjson_really_inline bool value_iterator::at_end() const noexcept { - return _json_iter->at_end(); -} - -simdjson_really_inline bool value_iterator::at_start() const noexcept { - return _json_iter->token.position() == start_position(); -} - -simdjson_really_inline bool value_iterator::at_first_field() const noexcept { - SIMDJSON_ASSUME(_json_iter->token._position > _start_position); - return _json_iter->token.position() == start_position() + 1; -} - -simdjson_really_inline void value_iterator::abandon() noexcept { - _json_iter->abandon(); -} - -simdjson_warn_unused simdjson_really_inline depth_t -value_iterator::depth() const noexcept { - return _depth; -} -simdjson_warn_unused simdjson_really_inline error_code -value_iterator::error() const noexcept { - return _json_iter->error; -} -simdjson_warn_unused simdjson_really_inline uint8_t *& -value_iterator::string_buf_loc() noexcept { - return _json_iter->string_buf_loc(); -} -simdjson_warn_unused simdjson_really_inline const json_iterator & -value_iterator::json_iter() const noexcept { - return *_json_iter; -} -simdjson_warn_unused 
simdjson_really_inline json_iterator & -value_iterator::json_iter() noexcept { - return *_json_iter; -} - -simdjson_really_inline const uint8_t *value_iterator::peek_start() const - noexcept { - return _json_iter->peek(start_position()); -} -simdjson_really_inline uint32_t value_iterator::peek_start_length() const - noexcept { - return _json_iter->peek_length(start_position()); -} - -simdjson_really_inline const uint8_t *value_iterator::peek_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - // If we're not at the position anymore, we don't want to advance the - // cursor. - if (!is_at_start()) { - return peek_start(); - } - - // Get the JSON and advance the cursor, decreasing depth to signify that we - // have retrieved the value. - assert_at_start(); - return _json_iter->peek(); -} - -simdjson_really_inline void value_iterator::advance_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - // If we're not at the position anymore, we don't want to advance the - // cursor. - if (!is_at_start()) { - return; - } - - // Get the JSON and advance the cursor, decreasing depth to signify that we - // have retrieved the value. - assert_at_start(); - _json_iter->return_current_and_advance(); - _json_iter->ascend_to(depth() - 1); -} - -simdjson_really_inline error_code -value_iterator::start_container(uint8_t start_char, - const char *incorrect_type_message, - const char *type) noexcept { - logger::log_start_value(*_json_iter, start_position(), depth(), type); - // If we're not at the position anymore, we don't want to advance the - // cursor. 
- const uint8_t *json; - if (!is_at_start()) { -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - if (!is_at_iterator_start()) { - return OUT_OF_ORDER_ITERATION; - } -#endif - json = peek_start(); - if (*json != start_char) { - return incorrect_type_error(incorrect_type_message); - } - } else { - assert_at_start(); - /** - * We should be prudent. Let us peek. If it is not the right type, we - * return an error. Only once we have determined that we have the right - * type are we allowed to advance! - */ - json = _json_iter->peek(); - if (*json != start_char) { - return incorrect_type_error(incorrect_type_message); - } - _json_iter->return_current_and_advance(); - } - - - return SUCCESS; -} - - -simdjson_really_inline const uint8_t *value_iterator::peek_root_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - if (!is_at_start()) { - return peek_start(); - } - - assert_at_root(); - return _json_iter->peek(); -} -simdjson_really_inline const uint8_t *value_iterator::peek_non_root_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - if (!is_at_start()) { - return peek_start(); - } - - assert_at_non_root_start(); - return _json_iter->peek(); -} - -simdjson_really_inline void value_iterator::advance_root_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - if (!is_at_start()) { - return; - } - - assert_at_root(); - _json_iter->return_current_and_advance(); - _json_iter->ascend_to(depth() - 1); -} -simdjson_really_inline void value_iterator::advance_non_root_scalar( - const char *type) noexcept { - logger::log_value(*_json_iter, start_position(), depth(), type); - if (!is_at_start()) { - return; - } - - assert_at_non_root_start(); - _json_iter->return_current_and_advance(); - _json_iter->ascend_to(depth() - 1); -} - -simdjson_really_inline error_code -value_iterator::incorrect_type_error(const char *message) const noexcept 
{ - logger::log_error(*_json_iter, start_position(), depth(), message); - return INCORRECT_TYPE; -} - -simdjson_really_inline bool value_iterator::is_at_start() const noexcept { - return position() == start_position(); -} - -simdjson_really_inline bool value_iterator::is_at_key() const noexcept { - // Keys are at the same depth as the object. - // Note here that we could be safer and check that we are within an object, - // but we do not. - return _depth == _json_iter->_depth && *_json_iter->peek() == '"'; -} - -simdjson_really_inline bool value_iterator::is_at_iterator_start() const - noexcept { - // We can legitimately be either at the first value ([1]), or after the - // array if it's empty ([]). - auto delta = position() - start_position(); - return delta == 1 || delta == 2; -} - -inline void value_iterator::assert_at_start() const noexcept { - SIMDJSON_ASSUME(_json_iter->token._position == _start_position); - SIMDJSON_ASSUME(_json_iter->_depth == _depth); - SIMDJSON_ASSUME(_depth > 0); -} - -inline void value_iterator::assert_at_container_start() const noexcept { - SIMDJSON_ASSUME(_json_iter->token._position == _start_position + 1); - SIMDJSON_ASSUME(_json_iter->_depth == _depth); - SIMDJSON_ASSUME(_depth > 0); -} - -inline void value_iterator::assert_at_next() const noexcept { - SIMDJSON_ASSUME(_json_iter->token._position > _start_position); - SIMDJSON_ASSUME(_json_iter->_depth == _depth); - SIMDJSON_ASSUME(_depth > 0); -} - -simdjson_really_inline void value_iterator::move_at_start() noexcept { - _json_iter->_depth = _depth; - _json_iter->token.set_position(_start_position); -} - -simdjson_really_inline void value_iterator::move_at_container_start() noexcept { - _json_iter->_depth = _depth; - _json_iter->token.set_position(_start_position + 1); -} - -simdjson_really_inline simdjson_result -value_iterator::reset_array() noexcept { - move_at_container_start(); - return started_array(); -} - -simdjson_really_inline simdjson_result 
-value_iterator::reset_object() noexcept { - move_at_container_start(); - return started_object(); -} - -inline void value_iterator::assert_at_child() const noexcept { - SIMDJSON_ASSUME(_json_iter->token._position > _start_position); - SIMDJSON_ASSUME(_json_iter->_depth == _depth + 1); - SIMDJSON_ASSUME(_depth > 0); -} - -inline void value_iterator::assert_at_root() const noexcept { - assert_at_start(); - SIMDJSON_ASSUME(_depth == 1); -} - -inline void value_iterator::assert_at_non_root_start() const noexcept { - assert_at_start(); - SIMDJSON_ASSUME(_depth > 1); -} - -inline void value_iterator::assert_is_valid() const noexcept { - SIMDJSON_ASSUME(_json_iter != nullptr); -} - -simdjson_really_inline bool value_iterator::is_valid() const noexcept { - return _json_iter != nullptr; -} - -simdjson_really_inline simdjson_result value_iterator::type() const - noexcept { - switch (*peek_start()) { - case '{': - return json_type::object; - case '[': - return json_type::array; - case '"': - return json_type::string; - case 'n': - return json_type::null; - case 't': - case 'f': - return json_type::boolean; - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return json_type::number; - default: - return TAPE_ERROR; - } -} - -simdjson_really_inline token_position value_iterator::start_position() const - noexcept { - return _start_position; -} - -simdjson_really_inline token_position value_iterator::position() const - noexcept { - return _json_iter->position(); -} - -simdjson_really_inline token_position value_iterator::end_position() const - noexcept { - return _json_iter->end_position(); -} - -simdjson_really_inline token_position value_iterator::last_position() const - noexcept { - return _json_iter->last_position(); -} - -simdjson_really_inline error_code -value_iterator::report_error(error_code error, const char *message) noexcept { - return _json_iter->report_error(error, message); -} - -} 
// namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator>( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value_iterator>(error) {} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/value_iterator-inl.h */ -/* begin file include/simdjson/generic/ondemand/array_iterator-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline array_iterator::array_iterator( - const value_iterator &_iter) noexcept : iter{_iter} {} - -simdjson_really_inline simdjson_result array_iterator::operator - *() noexcept { - if (iter.error()) { - iter.abandon(); - return iter.error(); - } - return value(iter.child()); -} -simdjson_really_inline bool array_iterator::operator==( - const array_iterator &other) const noexcept { - return !(*this != other); -} -simdjson_really_inline bool array_iterator::operator!=( - const array_iterator &) const noexcept { - return iter.is_open(); -} -simdjson_really_inline array_iterator &array_iterator::operator++() noexcept { - error_code error; - // PERF NOTE this is a safety rail ... users should exit loops as soon as - // they receive an error, so we'll never get here. - // However, it does not seem to make a perf difference, so we add it out of - // an abundance of caution. 
- if ((error = iter.error())) { - return *this; - } - if ((error = iter.skip_child())) { - return *this; - } - if ((error = iter.has_next_element().error())) { - return *this; - } - return *this; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator - &&value) noexcept - : SIMDJSON_BUILTIN_IMPLEMENTATION::implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator>( - value)) { - first.iter.assert_is_valid(); -} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : SIMDJSON_BUILTIN_IMPLEMENTATION::implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator>({}, - error) {} - -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> - simdjson_result:: - operator*() noexcept { - if (error()) { - return error(); - } - return *first; -} -simdjson_really_inline bool -simdjson_result:: -operator==(const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> &other) - const noexcept { - if (!first.iter.is_valid()) { - return !error(); - } - return first == other.first; -} -simdjson_really_inline bool -simdjson_result:: -operator!=(const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> &other) - const noexcept { - if (!first.iter.is_valid()) { - return error(); - } - return first != other.first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> & - simdjson_result:: - operator++() noexcept { - // Clear the error if there is one, so we don't yield it twice - if (error()) { - second = SUCCESS; - return *this; - } - ++(first); - return *this; -} - -} // 
namespace simdjson -/* end file include/simdjson/generic/ondemand/array_iterator-inl.h */ -/* begin file include/simdjson/generic/ondemand/object_iterator-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -// -// object_iterator -// - -simdjson_really_inline object_iterator::object_iterator( - const value_iterator &_iter) noexcept : iter{_iter} {} - -simdjson_really_inline simdjson_result object_iterator::operator - *() noexcept { - error_code error = iter.error(); - if (error) { - iter.abandon(); - return error; - } - auto result = field::start(iter); - // TODO this is a safety rail ... users should exit loops as soon as they - // receive an error. - // Nonetheless, let's see if performance is OK with this if statement--the - // compiler may give it to us for free. - if (result.error()) { - iter.abandon(); - } - return result; -} -simdjson_really_inline bool object_iterator::operator==( - const object_iterator &other) const noexcept { - return !(*this != other); -} -simdjson_really_inline bool object_iterator::operator!=( - const object_iterator &) const noexcept { - return iter.is_open(); -} - -simdjson_really_inline object_iterator &object_iterator::operator++() noexcept { - // TODO this is a safety rail ... users should exit loops as soon as they - // receive an error. - // Nonetheless, let's see if performance is OK with this if statement--the - // compiler may give it to us for free. - if (!iter.is_open()) { - return *this; - } // Iterator will be released if there is an error - - simdjson_unused error_code error; - if ((error = iter.skip_child())) { - return *this; - } - - simdjson_unused bool has_value; - if ((error = iter.has_next_field().get(has_value))) { - return *this; - }; - return *this; -} - -// -// ### Live States -// -// While iterating or looking up values, depth >= iter.depth. at_start may vary. 
-// Error is -// always SUCCESS: -// -// - Start: This is the state when the object is first found and the iterator is -// just past the {. -// In this state, at_start == true. -// - Next: After we hand a scalar value to the user, or an array/object which -// they then fully -// iterate over, the iterator is at the , or } before the next value. In this -// state, -// depth == iter.depth, at_start == false, and error == SUCCESS. -// - Unfinished Business: When we hand an array/object to the user which they do -// not fully -// iterate over, we need to finish that iteration by skipping child values -// until we reach the -// Next state. In this state, depth > iter.depth, at_start == false, and error -// == SUCCESS. -// -// ## Error States -// -// In error states, we will yield exactly one more value before stopping. -// iter.depth == depth -// and at_start is always false. We decrement after yielding the error, moving -// to the Finished -// state. -// -// - Chained Error: When the object iterator is part of an error chain--for -// example, in -// `for (auto tweet : doc["tweets"])`, where the tweet field may be missing or -// not be an -// object--we yield that error in the loop, exactly once. In this state, error -// != SUCCESS and -// iter.depth == depth, and at_start == false. We decrement depth when we -// yield the error. -// - Missing Comma Error: When the iterator ++ method discovers there is no -// comma between fields, -// we flag that as an error and treat it exactly the same as a Chained Error. -// In this state, -// error == TAPE_ERROR, iter.depth == depth, and at_start == false. -// -// Errors that occur while reading a field to give to the user (such as when the -// key is not a -// string or the field is missing a colon) are yielded immediately. Depth is -// then decremented, -// moving to the Finished state without transitioning through an Error state at -// all. -// -// ## Terminal State -// -// The terminal state has iter.depth < depth. 
at_start is always false. -// -// - Finished: When we have reached a }, we are finished. We signal this by -// decrementing depth. -// In this state, iter.depth < depth, at_start == false, and error == SUCCESS. -// - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator>( - value)) { - first.iter.assert_is_valid(); -} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator>({}, - error) {} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator>:: - operator*() noexcept { - if (error()) { - return error(); - } - return *first; -} -// If we're iterating and there is an error, return the error once. -simdjson_really_inline bool -simdjson_result:: -operator==(const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> &other) - const noexcept { - if (!first.iter.is_valid()) { - return !error(); - } - return first == other.first; -} -// If we're iterating and there is an error, return the error once. 
-simdjson_really_inline bool -simdjson_result:: -operator!=(const simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator> &other) - const noexcept { - if (!first.iter.is_valid()) { - return error(); - } - return first != other.first; -} -// Checks for ']' and ',' -simdjson_really_inline - simdjson_result - &simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object_iterator>:: - operator++() noexcept { - // Clear the error if there is one, so we don't yield it twice - if (error()) { - second = SUCCESS; - return *this; - } - ++first; - return *this; -} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/object_iterator-inl.h */ -/* begin file include/simdjson/generic/ondemand/array-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -// -// ### Live States -// -// While iterating or looking up values, depth >= iter->depth. at_start may -// vary. Error is -// always SUCCESS: -// -// - Start: This is the state when the array is first found and the iterator is -// just past the `{`. -// In this state, at_start == true. -// - Next: After we hand a scalar value to the user, or an array/object which -// they then fully -// iterate over, the iterator is at the `,` before the next value (or `]`). In -// this state, -// depth == iter->depth, at_start == false, and error == SUCCESS. -// - Unfinished Business: When we hand an array/object to the user which they do -// not fully -// iterate over, we need to finish that iteration by skipping child values -// until we reach the -// Next state. In this state, depth > iter->depth, at_start == false, and -// error == SUCCESS. -// -// ## Error States -// -// In error states, we will yield exactly one more value before stopping. -// iter->depth == depth -// and at_start is always false. We decrement after yielding the error, moving -// to the Finished -// state. 
-// -// - Chained Error: When the array iterator is part of an error chain--for -// example, in -// `for (auto tweet : doc["tweets"])`, where the tweet element may be missing -// or not be an -// array--we yield that error in the loop, exactly once. In this state, error -// != SUCCESS and -// iter->depth == depth, and at_start == false. We decrement depth when we -// yield the error. -// - Missing Comma Error: When the iterator ++ method discovers there is no -// comma between elements, -// we flag that as an error and treat it exactly the same as a Chained Error. -// In this state, -// error == TAPE_ERROR, iter->depth == depth, and at_start == false. -// -// ## Terminal State -// -// The terminal state has iter->depth < depth. at_start is always false. -// -// - Finished: When we have reached a `]` or have reported an error, we are -// finished. We signal this -// by decrementing depth. In this state, iter->depth < depth, at_start == -// false, and -// error == SUCCESS. -// - -simdjson_really_inline array::array(const value_iterator &_iter) noexcept - : iter{_iter} {} - -simdjson_really_inline simdjson_result array::start( - value_iterator &iter) noexcept { - // We don't need to know if the array is empty to start iteration, but we do - // want to know if there - // is an error--thus `simdjson_unused`. 
- simdjson_unused bool has_value; - SIMDJSON_TRY(iter.start_array().get(has_value)); - return array(iter); -} -simdjson_really_inline simdjson_result array::start_root( - value_iterator &iter) noexcept { - simdjson_unused bool has_value; - SIMDJSON_TRY(iter.start_root_array().get(has_value)); - return array(iter); -} -simdjson_really_inline simdjson_result array::started( - value_iterator &iter) noexcept { - bool has_value; - SIMDJSON_TRY(iter.started_array().get(has_value)); - return array(iter); -} - -simdjson_really_inline simdjson_result array::begin() noexcept { -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - if (!iter.is_at_iterator_start()) { - return OUT_OF_ORDER_ITERATION; - } -#endif - return array_iterator(iter); -} -simdjson_really_inline simdjson_result array::end() noexcept { - return array_iterator(iter); -} -simdjson_really_inline error_code array::consume() noexcept { - auto error = iter.json_iter().skip_child(iter.depth() - 1); - if (error) { - iter.abandon(); - } - return error; -} - -simdjson_really_inline simdjson_result -array::raw_json() noexcept { - const uint8_t *starting_point{iter.peek_start()}; - auto error = consume(); - if (error) { - return error; - } - // After 'consume()', we could be left pointing just beyond the document, - // but that - // is ok because we are not going to dereference the final pointer position, - // we just - // use it to compute the length in bytes. - const uint8_t *final_point{iter._json_iter->unsafe_pointer()}; - return std::string_view(reinterpret_cast(starting_point), - size_t(final_point - starting_point)); -} - -SIMDJSON_DISABLE_STRICT_OVERFLOW_WARNING -simdjson_really_inline simdjson_result array::count_elements() & - noexcept { - size_t count{0}; - // Important: we do not consume any of the values. - for (simdjson_unused auto v : *this) { - count++; - } - // The above loop will always succeed, but we want to report errors. 
- if (iter.error()) { - return iter.error(); - } - // We need to move back at the start because we expect users to iterate - // through - // the array after counting the number of elements. - iter.reset_array(); - return count; -} - -simdjson_really_inline simdjson_result array::is_empty() & noexcept { - bool is_not_empty; - auto error = iter.reset_array().get(is_not_empty); - if (error) { - return error; - } - return !is_not_empty; -} - -inline simdjson_result array::reset() & noexcept { - return iter.reset_array(); -} - -inline simdjson_result array::at_pointer( - std::string_view json_pointer) noexcept { - if (json_pointer[0] != '/') { - return INVALID_JSON_POINTER; - } - json_pointer = json_pointer.substr(1); - // - means "the append position" or "the element after the end of the array" - // We don't support this, because we're returning a real element, not a - // position. - if (json_pointer == "-") { - return INDEX_OUT_OF_BOUNDS; - } - - // Read the array index - size_t array_index = 0; - size_t i; - for (i = 0; i < json_pointer.length() && json_pointer[i] != '/'; i++) { - uint8_t digit = uint8_t(json_pointer[i] - '0'); - // Check for non-digit in array index. If it's there, we're trying to - // get a field in an object - if (digit > 9) { - return INCORRECT_TYPE; - } - array_index = array_index * 10 + digit; - } - - // 0 followed by other digits is invalid - if (i > 1 && json_pointer[0] == '0') { - return INVALID_JSON_POINTER; - } // "JSON pointer array index has other characters after 0" - - // Empty string is invalid; so is a "/" with no digits before it - if (i == 0) { - return INVALID_JSON_POINTER; - } // "Empty string in JSON pointer array index" - // Get the child - auto child = at(array_index); - // If there is an error, it ends here - if (child.error()) { - return child; - } - - // If there is a /, we're not done yet, call recursively. 
- if (i < json_pointer.length()) { - child = child.at_pointer(json_pointer.substr(i)); - } - return child; -} - -simdjson_really_inline simdjson_result array::at(size_t index) noexcept { - size_t i = 0; - for (auto value : *this) { - if (i == index) { - return value; - } - i++; - } - return INDEX_OUT_OF_BOUNDS; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>(error) {} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>::begin() noexcept { - if (error()) { - return error(); - } - return first.begin(); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>::end() noexcept { - if (error()) { - return error(); - } - return first.end(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>::count_elements() & - noexcept { - if (error()) { - return error(); - } - return first.count_elements(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array>::is_empty() & - noexcept { - if (error()) { - return error(); - } - return first.is_empty(); -} -simdjson_really_inline - simdjson_result - simdjson_result::at( - size_t index) noexcept { - if (error()) { - return error(); - } - return first.at(index); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result::at_pointer( - 
std::string_view json_pointer) noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/array-inl.h */ -/* begin file include/simdjson/generic/ondemand/document-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline document::document( - ondemand::json_iterator &&_iter) noexcept - : iter{std::forward(_iter)} { - logger::log_start_value(iter, "document"); -} - -simdjson_really_inline document document::start(json_iterator &&iter) noexcept { - return document(std::forward(iter)); -} - -inline void document::rewind() noexcept { iter.rewind(); } - -inline std::string document::to_debug_string() noexcept { - return iter.to_string(); -} - -inline simdjson_result document::current_location() noexcept { - return iter.current_location(); -} - -inline bool document::is_alive() noexcept { return iter.is_alive(); } -simdjson_really_inline value_iterator -document::resume_value_iterator() noexcept { - return value_iterator(&iter, 1, iter.root_position()); -} -simdjson_really_inline value_iterator -document::get_root_value_iterator() noexcept { - return resume_value_iterator(); -} -simdjson_really_inline simdjson_result -document::start_or_resume_object() noexcept { - if (iter.at_root()) { - return get_object(); - } else { - return object::resume(resume_value_iterator()); - } -} -simdjson_really_inline simdjson_result document::get_value() noexcept { - // Make sure we start any arrays or objects before returning, so that - // start_root_() - // gets called. - iter.assert_at_document_depth(); - switch (*iter.peek()) { - case '[': - case '{': - return value(get_root_value_iterator()); - default: - // Unfortunately, scalar documents are a special case in simdjson - // and they cannot - // be safely converted to value instances. 
- return SCALAR_DOCUMENT_AS_VALUE; - // return value(get_root_value_iterator()); - } -} -simdjson_really_inline simdjson_result document::get_array() & noexcept { - auto value = get_root_value_iterator(); - return array::start_root(value); -} -simdjson_really_inline simdjson_result document::get_object() & - noexcept { - auto value = get_root_value_iterator(); - return object::start_root(value); -} -simdjson_really_inline simdjson_result -document::get_uint64() noexcept { - return get_root_value_iterator().get_root_uint64(); -} -simdjson_really_inline simdjson_result -document::get_uint64_in_string() noexcept { - return get_root_value_iterator().get_root_uint64_in_string(); -} -simdjson_really_inline simdjson_result document::get_int64() noexcept { - return get_root_value_iterator().get_root_int64(); -} -simdjson_really_inline simdjson_result -document::get_int64_in_string() noexcept { - return get_root_value_iterator().get_root_int64_in_string(); -} -simdjson_really_inline simdjson_result document::get_double() noexcept { - return get_root_value_iterator().get_root_double(); -} -simdjson_really_inline simdjson_result -document::get_double_in_string() noexcept { - return get_root_value_iterator().get_root_double_in_string(); -} -simdjson_really_inline simdjson_result -document::get_string() noexcept { - return get_root_value_iterator().get_root_string(); -} -simdjson_really_inline simdjson_result -document::get_raw_json_string() noexcept { - return get_root_value_iterator().get_root_raw_json_string(); -} -simdjson_really_inline simdjson_result document::get_bool() noexcept { - return get_root_value_iterator().get_root_bool(); -} -simdjson_really_inline bool document::is_null() noexcept { - return get_root_value_iterator().is_root_null(); -} - -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_array(); -} -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_object(); -} -template 
<> - simdjson_really_inline simdjson_result document::get() & - noexcept { - return get_raw_json_string(); -} -template <> - simdjson_really_inline simdjson_result document::get() & - noexcept { - return get_string(); -} -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_double(); -} -template <> - simdjson_really_inline simdjson_result document::get() & - noexcept { - return get_uint64(); -} -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_int64(); -} -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_bool(); -} -template <> - simdjson_really_inline simdjson_result document::get() & noexcept { - return get_value(); -} - -template <> - simdjson_really_inline simdjson_result document::get() && - noexcept { - return get_raw_json_string(); -} -template <> - simdjson_really_inline simdjson_result document::get() && - noexcept { - return get_string(); -} -template <> - simdjson_really_inline simdjson_result document::get() && noexcept { - return std::forward(*this).get_double(); -} -template <> - simdjson_really_inline simdjson_result document::get() && - noexcept { - return std::forward(*this).get_uint64(); -} -template <> - simdjson_really_inline simdjson_result document::get() && - noexcept { - return std::forward(*this).get_int64(); -} -template <> - simdjson_really_inline simdjson_result document::get() && noexcept { - return std::forward(*this).get_bool(); -} -template <> - simdjson_really_inline simdjson_result document::get() && noexcept { - return get_value(); -} - -template - simdjson_really_inline error_code document::get(T &out) & noexcept { - return get().get(out); -} -template - simdjson_really_inline error_code document::get(T &out) && noexcept { - return std::forward(*this).get().get(out); -} - -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline document::operator array() & noexcept(false) { - return get_array(); -} 
-simdjson_really_inline document::operator object() & noexcept(false) { - return get_object(); -} -simdjson_really_inline document::operator uint64_t() noexcept(false) { - return get_uint64(); -} -simdjson_really_inline document::operator int64_t() noexcept(false) { - return get_int64(); -} -simdjson_really_inline document::operator double() noexcept(false) { - return get_double(); -} -simdjson_really_inline document::operator std::string_view() noexcept(false) { - return get_string(); -} -simdjson_really_inline document::operator raw_json_string() noexcept(false) { - return get_raw_json_string(); -} -simdjson_really_inline document::operator bool() noexcept(false) { - return get_bool(); -} -simdjson_really_inline document::operator value() noexcept(false) { - return get_value(); -} - -#endif -simdjson_really_inline simdjson_result document::count_elements() & - noexcept { - auto a = get_array(); - simdjson_result answer = a.count_elements(); - /* If there was an array, we are now left pointing at its first element. */ - if (answer.error() == SUCCESS) { - iter._depth = - 1; /* undoing the increment so we go back at the doc depth.*/ - iter.assert_at_document_depth(); - } - return answer; -} -simdjson_really_inline simdjson_result document::count_fields() & - noexcept { - auto a = get_object(); - simdjson_result answer = a.count_fields(); - /* If there was an array, we are now left pointing at its first element. 
*/ - if (answer.error() == SUCCESS) { - iter._depth = - 1; /* undoing the increment so we go back at the doc depth.*/ - iter.assert_at_document_depth(); - } - return answer; -} -simdjson_really_inline simdjson_result document::at(size_t index) & - noexcept { - auto a = get_array(); - return a.at(index); -} -simdjson_really_inline simdjson_result document::begin() & - noexcept { - return get_array().begin(); -} -simdjson_really_inline simdjson_result document::end() & - noexcept { - return {}; -} - -simdjson_really_inline simdjson_result document::find_field( - std::string_view key) & - noexcept { - return start_or_resume_object().find_field(key); -} -simdjson_really_inline simdjson_result document::find_field( - const char *key) & - noexcept { - return start_or_resume_object().find_field(key); -} -simdjson_really_inline simdjson_result document::find_field_unordered( - std::string_view key) & - noexcept { - return start_or_resume_object().find_field_unordered(key); -} -simdjson_really_inline simdjson_result document::find_field_unordered( - const char *key) & - noexcept { - return start_or_resume_object().find_field_unordered(key); -} -simdjson_really_inline simdjson_result document::operator[]( - std::string_view key) & - noexcept { - return start_or_resume_object()[key]; -} -simdjson_really_inline simdjson_result document::operator[]( - const char *key) & - noexcept { - return start_or_resume_object()[key]; -} - -simdjson_really_inline error_code document::consume() noexcept { - auto error = iter.skip_child(0); - if (error) { - iter.abandon(); - } - return error; -} - -simdjson_really_inline simdjson_result -document::raw_json() noexcept { - auto _iter = get_root_value_iterator(); - const uint8_t *starting_point{_iter.peek_start()}; - auto error = consume(); - if (error) { - return error; - } - // After 'consume()', we could be left pointing just beyond the document, - // but that - // is ok because we are not going to dereference the final pointer position, - // 
we just - // use it to compute the length in bytes. - const uint8_t *final_point{iter.unsafe_pointer()}; - return std::string_view(reinterpret_cast(starting_point), - size_t(final_point - starting_point)); -} - -simdjson_really_inline simdjson_result document::type() noexcept { - return get_root_value_iterator().type(); -} - -simdjson_really_inline simdjson_result document::is_scalar() noexcept { - json_type this_type; - auto error = type().get(this_type); - if (error) { - return error; - } - return !((this_type == json_type::array) || - (this_type == json_type::object)); -} - -simdjson_really_inline bool document::is_negative() noexcept { - return get_root_value_iterator().is_root_negative(); -} - -simdjson_really_inline simdjson_result document::is_integer() noexcept { - return get_root_value_iterator().is_root_integer(); -} - -simdjson_really_inline simdjson_result -document::get_number_type() noexcept { - return get_root_value_iterator().get_root_number_type(); -} - -simdjson_really_inline simdjson_result document::get_number() noexcept { - return get_root_value_iterator().get_root_number(); -} - - -simdjson_really_inline simdjson_result -document::raw_json_token() noexcept { - auto _iter = get_root_value_iterator(); - return std::string_view(reinterpret_cast(_iter.peek_start()), - _iter.peek_start_length()); -} - -simdjson_really_inline simdjson_result document::at_pointer( - std::string_view json_pointer) noexcept { - rewind(); // Rewind the document each time at_pointer is called - if (json_pointer.empty()) { - return this->get_value(); - } - json_type t; - SIMDJSON_TRY(type().get(t)); - switch (t) { - case json_type::array: - return (*this).get_array().at_pointer(json_pointer); - case json_type::object: - return (*this).get_object().at_pointer(json_pointer); - default: - return INVALID_JSON_POINTER; - } -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline 
-simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>(error) {} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::count_elements() & - noexcept { - if (error()) { - return error(); - } - return first.count_elements(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::count_fields() & - noexcept { - if (error()) { - return error(); - } - return first.count_fields(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> - simdjson_result::at( - size_t index) & - noexcept { - if (error()) { - return error(); - } - return first.at(index); -} -simdjson_really_inline error_code simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::rewind() noexcept { - if (error()) { - return error(); - } - first.rewind(); - return SUCCESS; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::begin() & - noexcept { - if (error()) { - return error(); - } - return first.begin(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::end() & - noexcept { - return {}; -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field_unordered(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline - 
simdjson_result - simdjson_result:: - find_field_unordered(const char *key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> - simdjson_result:: - operator[](std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> - simdjson_result:: - operator[](const char *key) & - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field(const char *key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get_array() & - noexcept { - if (error()) { - return error(); - } - return first.get_array(); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get_object() & - noexcept { - if (error()) { - return error(); - } - return first.get_object(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_uint64() noexcept { - if (error()) { - return error(); - } - return first.get_uint64(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get_int64() noexcept { - if (error()) { - return error(); - } - return first.get_int64(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_double() noexcept { - if (error()) { - return error(); - } - return first.get_double(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - 
get_string() noexcept { - if (error()) { - return error(); - } - return first.get_string(); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - get_raw_json_string() noexcept { - if (error()) { - return error(); - } - return first.get_raw_json_string(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get_bool() noexcept { - if (error()) { - return error(); - } - return first.get_bool(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get_value() noexcept { - if (error()) { - return error(); - } - return first.get_value(); -} -simdjson_really_inline bool simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::is_null() noexcept { - if (error()) { - return error(); - } - return first.is_null(); -} - -template - simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get() & - noexcept { - if (error()) { - return error(); - } - return first.get(); -} -template - simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::get() && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first) - .get(); -} -template - simdjson_really_inline error_code - simdjson_result::get( - T &out) & - noexcept { - if (error()) { - return error(); - } - return first.get(out); -} -template - simdjson_really_inline error_code - simdjson_result::get( - T &out) && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first) - .get(out); -} - -template <> - simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document> - simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>() & - noexcept = delete; -template <> - simdjson_really_inline simdjson_result< - 
SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document> - simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>() && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first); -} -template <> - simdjson_really_inline error_code - simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &out) & - noexcept = delete; -template <> - simdjson_really_inline error_code - simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &out) && - noexcept { - if (error()) { - return error(); - } - out = std::forward( - first); - return SUCCESS; -} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::type() noexcept { - if (error()) { - return error(); - } - return first.type(); -} - -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>::is_scalar() noexcept { - if (error()) { - return error(); - } - return first.is_scalar(); -} - - -simdjson_really_inline bool -simdjson_result:: - is_negative() noexcept { - if (error()) { - return error(); - } - return first.is_negative(); -} - -simdjson_really_inline simdjson_result -simdjson_result:: - is_integer() noexcept { - if (error()) { - return error(); - } - return first.is_integer(); -} - -simdjson_really_inline - simdjson_result - simdjson_result:: - get_number_type() noexcept { - if (error()) { - return error(); - } - return first.get_number_type(); -} - -simdjson_really_inline - simdjson_result - simdjson_result:: - get_number() noexcept { - if (error()) { - return error(); - } - return first.get_number(); -} - - -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline - simdjson_result:: - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - 
return first; -} -simdjson_really_inline - simdjson_result:: - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator uint64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator int64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator double() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator std::string_view() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string() noexcept( - false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator bool() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -#endif - - -simdjson_really_inline simdjson_result -simdjson_result:: - current_location() noexcept { - if (error()) { - return error(); - } - return first.current_location(); -} - -simdjson_really_inline simdjson_result -simdjson_result:: - raw_json_token() noexcept { - if (error()) { - return error(); - } - return first.raw_json_token(); -} - -simdjson_really_inline - simdjson_result - simdjson_result:: - at_pointer(std::string_view json_pointer) noexcept { - if (error()) { 
- return error(); - } - return first.at_pointer(json_pointer); -} - - -} // namespace simdjson - - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline document_reference::document_reference() noexcept - : doc{nullptr} {} -simdjson_really_inline document_reference::document_reference( - document &d) noexcept : doc(&d) {} -simdjson_really_inline void document_reference::rewind() noexcept { - doc->rewind(); -} -simdjson_really_inline simdjson_result document_reference::get_array() & - noexcept { - return doc->get_array(); -} -simdjson_really_inline simdjson_result - document_reference::get_object() & noexcept { - return doc->get_object(); -} -simdjson_really_inline simdjson_result -document_reference::get_uint64() noexcept { - return doc->get_uint64(); -} -simdjson_really_inline simdjson_result -document_reference::get_int64() noexcept { - return doc->get_int64(); -} -simdjson_really_inline simdjson_result -document_reference::get_double() noexcept { - return doc->get_double(); -} -simdjson_really_inline simdjson_result -document_reference::get_string() noexcept { - return doc->get_string(); -} -simdjson_really_inline simdjson_result -document_reference::get_raw_json_string() noexcept { - return doc->get_raw_json_string(); -} -simdjson_really_inline simdjson_result -document_reference::get_bool() noexcept { - return doc->get_bool(); -} -simdjson_really_inline simdjson_result -document_reference::get_value() noexcept { - return doc->get_value(); -} -simdjson_really_inline bool document_reference::is_null() noexcept { - return doc->is_null(); -} - -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline document_reference::operator array() & noexcept(false) { - return array(*doc); -} -simdjson_really_inline document_reference::operator object() & noexcept(false) { - return object(*doc); -} -simdjson_really_inline document_reference::operator uint64_t() noexcept(false) { - return uint64_t(*doc); -} 
-simdjson_really_inline document_reference::operator int64_t() noexcept(false) { - return int64_t(*doc); -} -simdjson_really_inline document_reference::operator double() noexcept(false) { - return double(*doc); -} -simdjson_really_inline document_reference::operator std::string_view() noexcept( - false) { - return std::string_view(*doc); -} -simdjson_really_inline document_reference::operator raw_json_string() noexcept( - false) { - return raw_json_string(*doc); -} -simdjson_really_inline document_reference::operator bool() noexcept(false) { - return bool(*doc); -} -simdjson_really_inline document_reference::operator value() noexcept(false) { - return value(*doc); -} -#endif -simdjson_really_inline simdjson_result - document_reference::count_elements() & noexcept { - return doc->count_elements(); -} -simdjson_really_inline simdjson_result - document_reference::count_fields() & noexcept { - return doc->count_fields(); -} -simdjson_really_inline simdjson_result document_reference::at( - size_t index) & - noexcept { - return doc->at(index); -} -simdjson_really_inline simdjson_result - document_reference::begin() & noexcept { - return doc->begin(); -} -simdjson_really_inline simdjson_result - document_reference::end() & noexcept { - return doc->end(); -} -simdjson_really_inline simdjson_result document_reference::find_field( - std::string_view key) & - noexcept { - return doc->find_field(key); -} -simdjson_really_inline simdjson_result document_reference::find_field( - const char *key) & - noexcept { - return doc->find_field(key); -} -simdjson_really_inline simdjson_result document_reference::operator[]( - std::string_view key) & - noexcept { - return (*doc)[key]; -} -simdjson_really_inline simdjson_result document_reference::operator[]( - const char *key) & - noexcept { - return (*doc)[key]; -} -simdjson_really_inline simdjson_result - document_reference::find_field_unordered(std::string_view key) & noexcept { - return doc->find_field_unordered(key); -} 
-simdjson_really_inline simdjson_result - document_reference::find_field_unordered(const char *key) & noexcept { - return doc->find_field_unordered(key); -} -simdjson_really_inline simdjson_result -document_reference::type() noexcept { - return doc->type(); -} -simdjson_really_inline simdjson_result -document_reference::is_scalar() noexcept { - return doc->is_scalar(); -} -simdjson_really_inline simdjson_result -document_reference::current_location() noexcept { - return doc->current_location(); -} -simdjson_really_inline bool document_reference::is_negative() noexcept { - return doc->is_negative(); -} -simdjson_really_inline simdjson_result -document_reference::is_integer() noexcept { - return doc->is_integer(); -} -simdjson_really_inline simdjson_result -document_reference::get_number_type() noexcept { - return doc->get_number_type(); -} -simdjson_really_inline simdjson_result -document_reference::get_number() noexcept { - return doc->get_number(); -} -simdjson_really_inline simdjson_result -document_reference::raw_json_token() noexcept { - return doc->raw_json_token(); -} -simdjson_really_inline simdjson_result document_reference::at_pointer( - std::string_view json_pointer) noexcept { - return doc->at_pointer(json_pointer); -} -simdjson_really_inline simdjson_result -document_reference::raw_json() noexcept { - return doc->raw_json(); -} -simdjson_really_inline document_reference::operator document &() const - noexcept { - return *doc; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - - -namespace simdjson { -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference value, - error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>( - value), - error) {} - - -simdjson_really_inline 
simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - count_elements() & - noexcept { - if (error()) { - return error(); - } - return first.count_elements(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - count_fields() & - noexcept { - if (error()) { - return error(); - } - return first.count_fields(); -} -simdjson_really_inline - simdjson_result - simdjson_result::at(size_t index) & - noexcept { - if (error()) { - return error(); - } - return first.at(index); -} -simdjson_really_inline error_code -simdjson_result:: - rewind() noexcept { - if (error()) { - return error(); - } - first.rewind(); - return SUCCESS; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result::begin() & - noexcept { - if (error()) { - return error(); - } - return first.begin(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>::end() & - noexcept { - return {}; -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - find_field_unordered(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - find_field_unordered(const char *key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - operator[](std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline - 
simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - operator[](const char *key) & - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> - simdjson_result::find_field(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline - simdjson_result - simdjson_result::find_field(const char *key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline - simdjson_result - simdjson_result::get_array() & - noexcept { - if (error()) { - return error(); - } - return first.get_array(); -} -simdjson_really_inline - simdjson_result - simdjson_result::get_object() & - noexcept { - if (error()) { - return error(); - } - return first.get_object(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_uint64() noexcept { - if (error()) { - return error(); - } - return first.get_uint64(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_int64() noexcept { - if (error()) { - return error(); - } - return first.get_int64(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_double() noexcept { - if (error()) { - return error(); - } - return first.get_double(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_string() noexcept { - if (error()) { - return error(); - } - return first.get_string(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string> -simdjson_result:: - get_raw_json_string() noexcept { - if (error()) { - return error(); - } - return first.get_raw_json_string(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_bool() noexcept { - if (error()) { - return error(); - } - return first.get_bool(); -} -simdjson_really_inline simdjson_result< - 
SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result:: - get_value() noexcept { - if (error()) { - return error(); - } - return first.get_value(); -} -simdjson_really_inline bool -simdjson_result:: - is_null() noexcept { - if (error()) { - return error(); - } - return first.is_null(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type> -simdjson_result:: - type() noexcept { - if (error()) { - return error(); - } - return first.type(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - is_scalar() noexcept { - if (error()) { - return error(); - } - return first.is_scalar(); -} -simdjson_really_inline bool -simdjson_result:: - is_negative() noexcept { - if (error()) { - return error(); - } - return first.is_negative(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - is_integer() noexcept { - if (error()) { - return error(); - } - return first.is_integer(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::number_type> -simdjson_result:: - get_number_type() noexcept { - if (error()) { - return error(); - } - return first.get_number_type(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::number> -simdjson_result:: - get_number() noexcept { - if (error()) { - return error(); - } - return first.get_number(); -} -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: - operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() & - noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - 
SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator uint64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator int64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator double() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator std::string_view() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string() noexcept( - false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator bool() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -#endif - -simdjson_really_inline simdjson_result -simdjson_result:: - current_location() noexcept { - if (error()) { - return error(); - } - return first.current_location(); -} - -simdjson_really_inline simdjson_result -simdjson_result:: - raw_json_token() noexcept { - if (error()) { - return error(); - } - return first.raw_json_token(); -} - -simdjson_really_inline simdjson_result< - 
SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result:: - at_pointer(std::string_view json_pointer) noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} - - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/document-inl.h */ -/* begin file include/simdjson/generic/ondemand/value-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline value::value(const value_iterator &_iter) noexcept - : iter{_iter} {} -simdjson_really_inline value value::start(const value_iterator &iter) noexcept { - return iter; -} -simdjson_really_inline value -value::resume(const value_iterator &iter) noexcept { - return iter; -} - -simdjson_really_inline simdjson_result value::get_array() noexcept { - return array::start(iter); -} -simdjson_really_inline simdjson_result value::get_object() noexcept { - return object::start(iter); -} -simdjson_really_inline simdjson_result -value::start_or_resume_object() noexcept { - if (iter.at_start()) { - return get_object(); - } else { - return object::resume(iter); - } -} - -simdjson_really_inline simdjson_result -value::get_raw_json_string() noexcept { - return iter.get_raw_json_string(); -} -simdjson_really_inline simdjson_result -value::get_string() noexcept { - return iter.get_string(); -} -simdjson_really_inline simdjson_result value::get_double() noexcept { - return iter.get_double(); -} -simdjson_really_inline simdjson_result -value::get_double_in_string() noexcept { - return iter.get_double_in_string(); -} -simdjson_really_inline simdjson_result value::get_uint64() noexcept { - return iter.get_uint64(); -} -simdjson_really_inline simdjson_result -value::get_uint64_in_string() noexcept { - return iter.get_uint64_in_string(); -} -simdjson_really_inline simdjson_result value::get_int64() noexcept { - return iter.get_int64(); -} -simdjson_really_inline simdjson_result -value::get_int64_in_string() noexcept { - 
return iter.get_int64_in_string(); -} -simdjson_really_inline simdjson_result value::get_bool() noexcept { - return iter.get_bool(); -} -simdjson_really_inline bool value::is_null() noexcept { return iter.is_null(); } - -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_array(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_object(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_raw_json_string(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_string(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_number(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_double(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_uint64(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_int64(); -} -template <> -simdjson_really_inline simdjson_result value::get() noexcept { - return get_bool(); -} - -template -simdjson_really_inline error_code value::get(T &out) noexcept { - return get().get(out); -} - -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline value::operator array() noexcept(false) { - return get_array(); -} -simdjson_really_inline value::operator object() noexcept(false) { - return get_object(); -} -simdjson_really_inline value::operator uint64_t() noexcept(false) { - return get_uint64(); -} -simdjson_really_inline value::operator int64_t() noexcept(false) { - return get_int64(); -} -simdjson_really_inline value::operator double() noexcept(false) { - return get_double(); -} -simdjson_really_inline value::operator std::string_view() noexcept(false) { - return get_string(); -} -simdjson_really_inline value::operator raw_json_string() noexcept(false) { - return get_raw_json_string(); -} -simdjson_really_inline 
value::operator bool() noexcept(false) { - return get_bool(); -} -#endif - -simdjson_really_inline simdjson_result value::begin() & - noexcept { - return get_array().begin(); -} -simdjson_really_inline simdjson_result value::end() & noexcept { - return {}; -} -simdjson_really_inline simdjson_result value::count_elements() & - noexcept { - simdjson_result answer; - auto a = get_array(); - answer = a.count_elements(); - // count_elements leaves you pointing inside the array, at the first - // element. - // We need to move back so that the user can create a new array (which - // requires that - // we point at '['). - iter.move_at_start(); - return answer; -} -simdjson_really_inline simdjson_result value::count_fields() & - noexcept { - simdjson_result answer; - auto a = get_object(); - answer = a.count_fields(); - iter.move_at_start(); - return answer; -} -simdjson_really_inline simdjson_result value::at(size_t index) noexcept { - auto a = get_array(); - return a.at(index); -} - -simdjson_really_inline simdjson_result value::find_field( - std::string_view key) noexcept { - return start_or_resume_object().find_field(key); -} -simdjson_really_inline simdjson_result value::find_field( - const char *key) noexcept { - return start_or_resume_object().find_field(key); -} - -simdjson_really_inline simdjson_result value::find_field_unordered( - std::string_view key) noexcept { - return start_or_resume_object().find_field_unordered(key); -} -simdjson_really_inline simdjson_result value::find_field_unordered( - const char *key) noexcept { - return start_or_resume_object().find_field_unordered(key); -} - -simdjson_really_inline simdjson_result value::operator[]( - std::string_view key) noexcept { - return start_or_resume_object()[key]; -} -simdjson_really_inline simdjson_result value::operator[]( - const char *key) noexcept { - return start_or_resume_object()[key]; -} - -simdjson_really_inline simdjson_result value::type() noexcept { - return iter.type(); -} - 
-simdjson_really_inline simdjson_result value::is_scalar() noexcept { - json_type this_type; - auto error = type().get(this_type); - if (error) { - return error; - } - return !((this_type == json_type::array) || - (this_type == json_type::object)); -} - -simdjson_really_inline bool value::is_negative() noexcept { - return iter.is_negative(); -} - -simdjson_really_inline simdjson_result value::is_integer() noexcept { - return iter.is_integer(); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value::get_number_type() noexcept { - return iter.get_number_type(); -} -simdjson_warn_unused simdjson_really_inline simdjson_result -value::get_number() noexcept { - return iter.get_number(); -} - -simdjson_really_inline std::string_view value::raw_json_token() noexcept { - return std::string_view(reinterpret_cast(iter.peek_start()), - iter.peek_start_length()); -} - -simdjson_really_inline simdjson_result -value::current_location() noexcept { - return iter.json_iter().current_location(); -} - -simdjson_really_inline simdjson_result value::at_pointer( - std::string_view json_pointer) noexcept { - json_type t; - SIMDJSON_TRY(type().get(t)); - switch (t) { - case json_type::array: - return (*this).get_array().at_pointer(json_pointer); - case json_type::object: - return (*this).get_object().at_pointer(json_pointer); - default: - return INVALID_JSON_POINTER; - } -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>(error) {} -simdjson_really_inline simdjson_result 
simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::count_elements() & - noexcept { - if (error()) { - return error(); - } - return first.count_elements(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::count_fields() & - noexcept { - if (error()) { - return error(); - } - return first.count_fields(); -} -simdjson_really_inline - simdjson_result - simdjson_result::at( - size_t index) noexcept { - if (error()) { - return error(); - } - return first.at(index); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result::begin() & - noexcept { - if (error()) { - return error(); - } - return first.begin(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array_iterator> - simdjson_result::end() & - noexcept { - if (error()) { - return error(); - } - return {}; -} - -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result::find_field( - std::string_view key) noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result::find_field( - const char *key) noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} - -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field_unordered(std::string_view key) noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field_unordered(const char *key) noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} - -simdjson_really_inline - simdjson_result - simdjson_result:: - operator[](std::string_view key) noexcept { - if (error()) { - return error(); - } - return first[key]; -} 
-simdjson_really_inline - simdjson_result - simdjson_result:: - operator[](const char *key) noexcept { - if (error()) { - return error(); - } - return first[key]; -} - -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array> -simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_array() noexcept { - if (error()) { - return error(); - } - return first.get_array(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object> -simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_object() noexcept { - if (error()) { - return error(); - } - return first.get_object(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_uint64() noexcept { - if (error()) { - return error(); - } - return first.get_uint64(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_uint64_in_string() noexcept { - if (error()) { - return error(); - } - return first.get_uint64_in_string(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_int64() noexcept { - if (error()) { - return error(); - } - return first.get_int64(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_int64_in_string() noexcept { - if (error()) { - return error(); - } - return first.get_int64_in_string(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_double() noexcept { - if (error()) { - return error(); - } - return first.get_double(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - get_double_in_string() noexcept { - if (error()) { - return error(); - } - return first.get_double_in_string(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_string() noexcept { - if (error()) { - return error(); - } - 
return first.get_string(); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - get_raw_json_string() noexcept { - if (error()) { - return error(); - } - return first.get_raw_json_string(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_bool() noexcept { - if (error()) { - return error(); - } - return first.get_bool(); -} -simdjson_really_inline bool simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::is_null() noexcept { - if (error()) { - return false; - } - return first.is_null(); -} - -template -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get() noexcept { - if (error()) { - return error(); - } - return first.get(); -} -template -simdjson_really_inline error_code -simdjson_result::get( - T &out) noexcept { - if (error()) { - return error(); - } - return first.get(out); -} - -template <> -simdjson_really_inline - simdjson_result - simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>() noexcept { - if (error()) { - return error(); - } - return std::move(first); -} -template <> -simdjson_really_inline error_code -simdjson_result::get< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value &out) noexcept { - if (error()) { - return error(); - } - out = first; - return SUCCESS; -} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::type() noexcept { - if (error()) { - return error(); - } - return first.type(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::is_scalar() noexcept { - if (error()) { - return error(); - } - return first.is_scalar(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::is_negative() noexcept { - if (error()) { - return 
error(); - } - return first.is_negative(); -} -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::is_integer() noexcept { - if (error()) { - return error(); - } - return first.is_integer(); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - get_number_type() noexcept { - if (error()) { - return error(); - } - return first.get_number_type(); -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::number> -simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>::get_number() noexcept { - if (error()) { - return error(); - } - return first.get_number(); -} -#if SIMDJSON_EXCEPTIONS -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator uint64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator int64_t() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator double() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator std::string_view() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value>:: -operator SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::raw_json_string() 
noexcept( - false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -simdjson_really_inline - simdjson_result:: - operator bool() noexcept(false) { - if (error()) { - throw simdjson_error(error()); - } - return first; -} -#endif - -simdjson_really_inline simdjson_result -simdjson_result:: - raw_json_token() noexcept { - if (error()) { - return error(); - } - return first.raw_json_token(); -} - -simdjson_really_inline simdjson_result -simdjson_result:: - current_location() noexcept { - if (error()) { - return error(); - } - return first.current_location(); -} - -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result::at_pointer( - std::string_view json_pointer) noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/value-inl.h */ -/* begin file include/simdjson/generic/ondemand/field-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -// clang 6 doesn't think the default constructor can be noexcept, so we make it -// explicit -simdjson_really_inline field::field() noexcept - : std::pair() {} - -simdjson_really_inline field::field(raw_json_string key, - ondemand::value &&value) noexcept - : std::pair( - key, std::forward(value)) {} - -simdjson_really_inline simdjson_result field::start( - value_iterator &parent_iter) noexcept { - raw_json_string key; - SIMDJSON_TRY(parent_iter.field_key().get(key)); - SIMDJSON_TRY(parent_iter.field_value()); - return field::start(parent_iter, key); -} - -simdjson_really_inline simdjson_result field::start( - const value_iterator &parent_iter, raw_json_string key) noexcept { - return field(key, parent_iter.child()); -} - -simdjson_really_inline simdjson_warn_unused simdjson_result -field::unescaped_key() noexcept { - SIMDJSON_ASSUME(first.buf != nullptr); // We would like to call .alive() - // but 
Visual Studio won't let us. - simdjson_result answer = - first.unescape(second.iter.string_buf_loc()); - first.consume(); - return answer; -} - -simdjson_really_inline raw_json_string field::key() const noexcept { - SIMDJSON_ASSUME(first.buf != nullptr); // We would like to call .alive() by - // Visual Studio won't let us. - return first; -} - -simdjson_really_inline value &field::value() & noexcept { return second; } - -simdjson_really_inline value field::value() && noexcept { - return std::forward(*this).second; -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field>(error) {} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field>::key() noexcept { - if (error()) { - return error(); - } - return first.key(); -} -simdjson_really_inline simdjson_result -simdjson_result:: - unescaped_key() noexcept { - if (error()) { - return error(); - } - return first.unescaped_key(); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::field>::value() noexcept { - if (error()) { - return error(); - } - return std::move(first.value()); -} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/field-inl.h */ -/* begin file include/simdjson/generic/ondemand/object-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline simdjson_result object::find_field_unordered( - const 
std::string_view key) & - noexcept { - bool has_value; - SIMDJSON_TRY(iter.find_field_unordered_raw(key).get(has_value)); - if (!has_value) { - return NO_SUCH_FIELD; - } - return value(iter.child()); -} -simdjson_really_inline simdjson_result object::find_field_unordered( - const std::string_view key) && - noexcept { - bool has_value; - SIMDJSON_TRY(iter.find_field_unordered_raw(key).get(has_value)); - if (!has_value) { - return NO_SUCH_FIELD; - } - return value(iter.child()); -} -simdjson_really_inline simdjson_result object::operator[]( - const std::string_view key) & - noexcept { - return find_field_unordered(key); -} -simdjson_really_inline simdjson_result object::operator[]( - const std::string_view key) && - noexcept { - return std::forward(*this).find_field_unordered(key); -} -simdjson_really_inline simdjson_result object::find_field( - const std::string_view key) & - noexcept { - bool has_value; - SIMDJSON_TRY(iter.find_field_raw(key).get(has_value)); - if (!has_value) { - return NO_SUCH_FIELD; - } - return value(iter.child()); -} -simdjson_really_inline simdjson_result object::find_field( - const std::string_view key) && - noexcept { - bool has_value; - SIMDJSON_TRY(iter.find_field_raw(key).get(has_value)); - if (!has_value) { - return NO_SUCH_FIELD; - } - return value(iter.child()); -} - -simdjson_really_inline simdjson_result object::start( - value_iterator &iter) noexcept { - SIMDJSON_TRY(iter.start_object().error()); - return object(iter); -} -simdjson_really_inline simdjson_result object::start_root( - value_iterator &iter) noexcept { - SIMDJSON_TRY(iter.start_root_object().error()); - return object(iter); -} -simdjson_really_inline error_code object::consume() noexcept { - if (iter.is_at_key()) { - /** - * whenever you are pointing at a key, calling skip_child() is - * unsafe because you will hit a string and you will assume that - * it is string value, and this mistake will lead you to make bad - * depth computation. 
- */ - /** - * We want to 'consume' the key. We could really - * just do _json_iter->return_current_and_advance(); at this - * point, but, for clarity, we will use the high-level API to - * eat the key. We assume that the compiler optimizes away - * most of the work. - */ - simdjson_unused raw_json_string actual_key; - auto error = iter.field_key().get(actual_key); - if (error) { - iter.abandon(); - return error; - }; - // Let us move to the value while we are at it. - if ((error = iter.field_value())) { - iter.abandon(); - return error; - } - } - auto error_skip = iter.json_iter().skip_child(iter.depth() - 1); - if (error_skip) { - iter.abandon(); - } - return error_skip; -} - -simdjson_really_inline simdjson_result -object::raw_json() noexcept { - const uint8_t *starting_point{iter.peek_start()}; - auto error = consume(); - if (error) { - return error; - } - const uint8_t *final_point{iter._json_iter->peek(0)}; - return std::string_view(reinterpret_cast(starting_point), - size_t(final_point - starting_point)); -} - -simdjson_really_inline simdjson_result object::started( - value_iterator &iter) noexcept { - SIMDJSON_TRY(iter.started_object().error()); - return object(iter); -} - -simdjson_really_inline object -object::resume(const value_iterator &iter) noexcept { - return iter; -} - -simdjson_really_inline object::object(const value_iterator &_iter) noexcept - : iter{_iter} {} - -simdjson_really_inline simdjson_result -object::begin() noexcept { -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - if (!iter.is_at_iterator_start()) { - return OUT_OF_ORDER_ITERATION; - } -#endif - return object_iterator(iter); -} -simdjson_really_inline simdjson_result object::end() noexcept { - return object_iterator(iter); -} - -inline simdjson_result object::at_pointer( - std::string_view json_pointer) noexcept { - if (json_pointer[0] != '/') { - return INVALID_JSON_POINTER; - } - json_pointer = json_pointer.substr(1); - size_t slash = json_pointer.find('/'); - std::string_view key = 
json_pointer.substr(0, slash); - // Grab the child with the given key - simdjson_result child; - - // If there is an escape character in the key, unescape it and then get the - // child. - size_t escape = key.find('~'); - if (escape != std::string_view::npos) { - // Unescape the key - std::string unescaped(key); - do { - switch (unescaped[escape + 1]) { - case '0': - unescaped.replace(escape, 2, "~"); - break; - case '1': - unescaped.replace(escape, 2, "/"); - break; - default: - return INVALID_JSON_POINTER; // "Unexpected ~ escape - // character in JSON - // pointer"); - } - escape = unescaped.find('~', escape + 1); - } while (escape != std::string::npos); - child = find_field(unescaped); // Take note find_field does not - // unescape keys when matching - } else { - child = find_field(key); - } - if (child.error()) { - return child; // we do not continue if there was an error - } - // If there is a /, we have to recurse and look up more of the path - if (slash != std::string_view::npos) { - child = child.at_pointer(json_pointer.substr(slash)); - } - return child; -} - -simdjson_really_inline simdjson_result object::count_fields() & - noexcept { - size_t count{0}; - // Important: we do not consume any of the values. - for (simdjson_unused auto v : *this) { - count++; - } - // The above loop will always succeed, but we want to report errors. - if (iter.error()) { - return iter.error(); - } - // We need to move back at the start because we expect users to iterate - // through - // the object after counting the number of elements. 
- iter.reset_object(); - return count; -} - -simdjson_really_inline simdjson_result object::is_empty() & noexcept { - bool is_not_empty; - auto error = iter.reset_object().get(is_not_empty); - if (error) { - return error; - } - return !is_not_empty; -} - -simdjson_really_inline simdjson_result object::reset() & noexcept { - return iter.reset_object(); -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>(error) {} - -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>::begin() noexcept { - if (error()) { - return error(); - } - return first.begin(); -} -simdjson_really_inline - simdjson_result - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>::end() noexcept { - if (error()) { - return error(); - } - return first.end(); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field_unordered(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field_unordered(key); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field_unordered(std::string_view key) && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first) - .find_field_unordered(key); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - operator[](std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first[key]; -} -simdjson_really_inline - simdjson_result - simdjson_result:: 
- operator[](std::string_view key) && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first)[key]; -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field(std::string_view key) & - noexcept { - if (error()) { - return error(); - } - return first.find_field(key); -} -simdjson_really_inline - simdjson_result - simdjson_result:: - find_field(std::string_view key) && - noexcept { - if (error()) { - return error(); - } - return std::forward( - first) - .find_field(key); -} - -simdjson_really_inline simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> -simdjson_result::at_pointer( - std::string_view json_pointer) noexcept { - if (error()) { - return error(); - } - return first.at_pointer(json_pointer); -} - -inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>::reset() noexcept { - if (error()) { - return error(); - } - return first.reset(); -} - -inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>::is_empty() noexcept { - if (error()) { - return error(); - } - return first.is_empty(); -} - -simdjson_really_inline simdjson_result simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object>::count_fields() & - noexcept { - if (error()) { - return error(); - } - return first.count_fields(); -} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/object-inl.h */ -/* begin file include/simdjson/generic/ondemand/parser-inl.h */ -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -simdjson_really_inline parser::parser(size_t max_capacity) noexcept - : _max_capacity{max_capacity} {} - -simdjson_warn_unused simdjson_really_inline error_code -parser::allocate(size_t new_capacity, size_t new_max_depth) noexcept { - if (new_capacity > max_capacity()) { - return CAPACITY; - } - if (string_buf && new_capacity == capacity() && - new_max_depth == max_depth()) { - return 
SUCCESS; - } - - // string_capacity copied from document::allocate - _capacity = 0; - size_t string_capacity = - SIMDJSON_ROUNDUP_N(5 * new_capacity / 3 + SIMDJSON_PADDING, 64); - string_buf.reset(new (std::nothrow) uint8_t[string_capacity]); -#ifdef SIMDJSON_DEVELOPMENT_CHECKS - start_positions.reset(new (std::nothrow) token_position[new_max_depth]); -#endif - if (implementation) { - SIMDJSON_TRY(implementation->set_capacity(new_capacity)); - SIMDJSON_TRY(implementation->set_max_depth(new_max_depth)); - } else { - SIMDJSON_TRY(simdjson::get_active_implementation() - ->create_dom_parser_implementation( - new_capacity, new_max_depth, implementation)); - } - _capacity = new_capacity; - _max_depth = new_max_depth; - return SUCCESS; -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(padded_string_view json) & noexcept { - if (json.padding() < SIMDJSON_PADDING) { - return INSUFFICIENT_PADDING; - } - - // Allocate if needed - if (capacity() < json.length() || !string_buf) { - SIMDJSON_TRY(allocate(json.length(), max_depth())); - } - - // Run stage 1. 
- SIMDJSON_TRY( - implementation->stage1(reinterpret_cast(json.data()), - json.length(), - stage1_mode::regular)); - return document::start( - {reinterpret_cast(json.data()), this}); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(const char *json, size_t len, size_t allocated) & noexcept { - return iterate(padded_string_view(json, len, allocated)); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(const uint8_t *json, size_t len, size_t allocated) & - noexcept { - return iterate(padded_string_view(json, len, allocated)); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(std::string_view json, size_t allocated) & noexcept { - return iterate(padded_string_view(json, allocated)); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(const std::string &json) & noexcept { - return iterate(padded_string_view(json)); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(const simdjson_result &result) & - noexcept { - // We don't presently have a way to temporarily get a const T& from a - // simdjson_result without throwing an exception - SIMDJSON_TRY(result.error()); - padded_string_view json = result.value_unsafe(); - return iterate(json); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate(const simdjson_result &result) & noexcept { - // We don't presently have a way to temporarily get a const T& from a - // simdjson_result without throwing an exception - SIMDJSON_TRY(result.error()); - const padded_string &json = result.value_unsafe(); - return iterate(json); -} - -simdjson_warn_unused simdjson_really_inline simdjson_result - parser::iterate_raw(padded_string_view json) & noexcept { - if (json.padding() < SIMDJSON_PADDING) { - return INSUFFICIENT_PADDING; - } - - // Allocate if needed - if (capacity() < json.length()) { - SIMDJSON_TRY(allocate(json.length(), max_depth())); - 
} - - // Run stage 1. - SIMDJSON_TRY( - implementation->stage1(reinterpret_cast(json.data()), - json.length(), - stage1_mode::regular)); - return json_iterator(reinterpret_cast(json.data()), this); -} - -inline simdjson_result parser::iterate_many( - const uint8_t *buf, size_t len, size_t batch_size) noexcept { - if (batch_size < MINIMAL_BATCH_SIZE) { - batch_size = MINIMAL_BATCH_SIZE; - } - return document_stream(*this, buf, len, batch_size); -} -inline simdjson_result parser::iterate_many( - const char *buf, size_t len, size_t batch_size) noexcept { - return iterate_many( - reinterpret_cast(buf), len, batch_size); -} -inline simdjson_result parser::iterate_many( - const std::string &s, size_t batch_size) noexcept { - return iterate_many(s.data(), s.length(), batch_size); -} -inline simdjson_result parser::iterate_many( - const padded_string &s, size_t batch_size) noexcept { - return iterate_many(s.data(), s.length(), batch_size); -} - -simdjson_really_inline size_t parser::capacity() const noexcept { - return _capacity; -} -simdjson_really_inline size_t parser::max_capacity() const noexcept { - return _max_capacity; -} -simdjson_really_inline size_t parser::max_depth() const noexcept { - return _max_depth; -} - -simdjson_really_inline void parser::set_max_capacity( - size_t max_capacity) noexcept { - size_t MINIMAL_DOCUMENT_CAPACITY = 32; - if (max_capacity < MINIMAL_DOCUMENT_CAPACITY) { - _max_capacity = max_capacity; - } else { - _max_capacity = MINIMAL_DOCUMENT_CAPACITY; - } -} - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::parser &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::parser>( - std::forward( - value)) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : 
implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::parser>(error) {} - -} // namespace simdjson -/* end file include/simdjson/generic/ondemand/parser-inl.h */ -/* begin file include/simdjson/generic/ondemand/document_stream-inl.h */ -#include -#include -#include -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -#ifdef SIMDJSON_THREADS_ENABLED - -inline void stage1_worker::finish() { - // After calling "run" someone would call finish() to wait - // for the end of the processing. - // This function will wait until either the thread has done - // the processing or, else, the destructor has been called. - std::unique_lock lock(locking_mutex); - cond_var.wait(lock, [this] { return has_work == false; }); -} - -inline stage1_worker::~stage1_worker() { - // The thread may never outlive the stage1_worker instance - // and will always be stopped/joined before the stage1_worker - // instance is gone. - stop_thread(); -} - -inline void stage1_worker::start_thread() { - std::unique_lock lock(locking_mutex); - if (thread.joinable()) { - return; // This should never happen but we never want to create more - // than one thread. - } - thread = std::thread([this] { - while (true) { - std::unique_lock thread_lock(locking_mutex); - // We wait for either "run" or "stop_thread" to be called. - cond_var.wait(thread_lock, - [this] { return has_work || !can_work; }); - // If, for some reason, the stop_thread() method was called (i.e., - // the - // destructor of stage1_worker is called, then we want to - // immediately destroy - // the thread (and not do any more processing). 
- if (!can_work) { - break; - } - this->owner->stage1_thread_error = this->owner->run_stage1( - *this->stage1_thread_parser, this->_next_batch_start); - this->has_work = false; - // The condition variable call should be moved after - // thread_lock.unlock() for performance - // reasons but thread sanitizers may report it as a data race if we - // do. - // See - // https://stackoverflow.com/questions/35775501/c-should-condition-variable-be-notified-under-lock - cond_var.notify_one(); // will notify "finish" - thread_lock.unlock(); - } - }); -} - - -inline void stage1_worker::stop_thread() { - std::unique_lock lock(locking_mutex); - // We have to make sure that all locks can be released. - can_work = false; - has_work = false; - cond_var.notify_all(); - lock.unlock(); - if (thread.joinable()) { - thread.join(); - } -} - -inline void stage1_worker::run(document_stream *ds, - parser *stage1, - size_t next_batch_start) { - std::unique_lock lock(locking_mutex); - owner = ds; - _next_batch_start = next_batch_start; - stage1_thread_parser = stage1; - has_work = true; - // The condition variable call should be moved after thread_lock.unlock() - // for performance - // reasons but thread sanitizers may report it as a data race if we do. - // See - // https://stackoverflow.com/questions/35775501/c-should-condition-variable-be-notified-under-lock - cond_var.notify_one(); // will notify the thread lock that we have work - lock.unlock(); -} - -#endif // SIMDJSON_THREADS_ENABLED - -simdjson_really_inline document_stream::document_stream( - ondemand::parser &_parser, - const uint8_t *_buf, - size_t _len, - size_t _batch_size) noexcept - : parser{&_parser}, - buf{_buf}, - len{_len}, - batch_size{_batch_size <= MINIMAL_BATCH_SIZE ? 
MINIMAL_BATCH_SIZE - : _batch_size}, - error { - SUCCESS -} -#ifdef SIMDJSON_THREADS_ENABLED -, use_thread(_parser.threaded) // we need to make a copy because - // _parser.threaded can change -#endif -{ -#ifdef SIMDJSON_THREADS_ENABLED - if (worker.get() == nullptr) { - error = MEMALLOC; - } -#endif -} - -simdjson_really_inline document_stream::document_stream() noexcept - : parser{nullptr}, - buf{nullptr}, - len{0}, - batch_size{0}, - error { - UNINITIALIZED -} -#ifdef SIMDJSON_THREADS_ENABLED -, use_thread(false) -#endif -{ -} - -simdjson_really_inline document_stream::~document_stream() noexcept { -#ifdef SIMDJSON_THREADS_ENABLED - worker.reset(); -#endif -} - -inline size_t document_stream::size_in_bytes() const noexcept { return len; } - -inline size_t document_stream::truncated_bytes() const noexcept { - if (error == CAPACITY) { - return len - batch_start; - } - return parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes] - - parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes + 1]; -} - -simdjson_really_inline document_stream::iterator::iterator() noexcept - : stream{nullptr}, - finished{true} {} - -simdjson_really_inline document_stream::iterator::iterator( - document_stream *_stream, bool is_end) noexcept : stream{_stream}, - finished{is_end} {} - -simdjson_really_inline simdjson_result - document_stream::iterator::operator*() noexcept { - // if(stream->error) { return stream->error; } - return simdjson_result(stream->doc, - stream->error); -} - -simdjson_really_inline document_stream::iterator - &document_stream::iterator::operator++() noexcept { - // If there is an error, then we want the iterator - // to be finished, no matter what. (E.g., we do not - // keep generating documents with errors, or go beyond - // a document with errors.) - // - // Users do not have to call "operator*()" when they use operator++, - // so we need to end the stream in the operator++ function. 
- // - // Note that setting finished = true is essential otherwise - // we would enter an infinite loop. - if (stream->error) { - finished = true; - } - // Note that stream->error() is guarded against error conditions - // (it will immediately return if stream->error casts to false). - // In effect, this next function does nothing when (stream->error) - // is true (hence the risk of an infinite loop). - stream->next(); - // If that was the last document, we're finished. - // It is the only type of error we do not want to appear - // in operator*. - if (stream->error == EMPTY) { - finished = true; - } - // If we had any other kind of error (not EMPTY) then we want - // to pass it along to the operator* and we cannot mark the result - // as "finished" just yet. - return *this; -} - -simdjson_really_inline bool document_stream::iterator::operator!=( - const document_stream::iterator &other) const noexcept { - return finished != other.finished; -} - -simdjson_really_inline document_stream::iterator -document_stream::begin() noexcept { - start(); - // If there are no documents, we're finished. 
- return iterator(this, error == EMPTY); -} - -simdjson_really_inline document_stream::iterator -document_stream::end() noexcept { - return iterator(this, true); -} - -inline void document_stream::start() noexcept { - if (error) { - return; - } - error = parser->allocate(batch_size); - if (error) { - return; - } - // Always run the first stage 1 parse immediately - batch_start = 0; - error = run_stage1(*parser, batch_start); - while (error == EMPTY) { - // In exceptional cases, we may start with an empty block - batch_start = next_batch_start(); - if (batch_start >= len) { - return; - } - error = run_stage1(*parser, batch_start); - } - if (error) { - return; - } - doc_index = batch_start; - doc = document(json_iterator(&buf[batch_start], parser)); - doc.iter._streaming = true; - -#ifdef SIMDJSON_THREADS_ENABLED - if (use_thread && next_batch_start() < len) { - // Kick off the first thread on next batch if needed - error = stage1_thread_parser.allocate(batch_size); - if (error) { - return; - } - worker->start_thread(); - start_stage1_thread(); - if (error) { - return; - } - } -#endif // SIMDJSON_THREADS_ENABLED -} - -inline void document_stream::next() noexcept { - // We always enter at once once in an error condition. - if (error) { - return; - } - next_document(); - if (error) { - return; - } - auto cur_struct_index = - doc.iter._root - parser->implementation->structural_indexes.get(); - doc_index = batch_start + - parser->implementation->structural_indexes[cur_struct_index]; - - // Check if at end of structural indexes (i.e. 
at end of batch) - if (cur_struct_index >= - static_cast(parser->implementation->n_structural_indexes)) { - error = EMPTY; - // Load another batch (if available) - while (error == EMPTY) { - batch_start = next_batch_start(); - if (batch_start >= len) { - break; - } -#ifdef SIMDJSON_THREADS_ENABLED - if (use_thread) { - load_from_stage1_thread(); - } else { - error = run_stage1(*parser, batch_start); - } -#else - error = run_stage1(*parser, batch_start); -#endif - /** - * Whenever we move to another window, we need to update all - * pointers to make - * it appear as if the input buffer started at the beginning of the - * window. - * - * Take this input: - * - * {"z":5} {"1":1,"2":2,"4":4} [7, 10, 9] [15, 11, 12, 13] - * [154, 110, 112, 1311] - * - * Say you process the following window... - * - * '{"z":5} {"1":1,"2":2,"4":4} [7, 10, 9]' - * - * When you do so, the json_iterator has a pointer at the beginning - * of the memory region - * (pointing at the beginning of '{"z"...'. - * - * When you move to the window that starts at... - * - * '[7, 10, 9] [15, 11, 12, 13] ... - * - * then it is not sufficient to just run stage 1. You also need to - * re-anchor the - * json_iterator so that it believes we are starting at '[7, 10, - * 9]...'. - * - * Under the DOM front-end, this gets done automatically because the - * parser owns - * the pointer the data, and when you call stage1 and then stage2 on - * the same - * parser, then stage2 will run on the pointer acquired by stage1. - * - * That is, stage1 calls "this->buf = _buf" so the parser remembers - * the buffer that - * we used. But json_iterator has no callback when stage1 is called - * on the parser. - * In fact, I think that the parser is unaware of json_iterator. - * - * - * So we need to re-anchor the json_iterator after each call to - * stage 1 so that - * all of the pointers are in sync. - */ - doc.iter = json_iterator(&buf[batch_start], parser); - doc.iter._streaming = true; - /** - * End of resync. 
- */ - - if (error) { - continue; - } // If the error was EMPTY, we may want to load another batch. - doc_index = batch_start; - } - } -} - -inline void document_stream::next_document() noexcept { - // Go to next place where depth=0 (document depth) - error = doc.iter.skip_child(0); - if (error) { - return; - } - // Always set depth=1 at the start of document - doc.iter._depth = 1; - // Resets the string buffer at the beginning, thus invalidating the strings. - doc.iter._string_buf_loc = parser->string_buf.get(); - doc.iter._root = doc.iter.position(); -} - -inline size_t document_stream::next_batch_start() const noexcept { - return batch_start + - parser->implementation->structural_indexes - [parser->implementation->n_structural_indexes]; -} - -inline error_code document_stream::run_stage1(ondemand::parser &p, - size_t _batch_start) noexcept { - // This code only updates the structural index in the parser, it does not - // update any json_iterator - // instance. - size_t remaining = len - _batch_start; - if (remaining <= batch_size) { - return p.implementation->stage1( - &buf[_batch_start], remaining, stage1_mode::streaming_final); - } else { - return p.implementation->stage1( - &buf[_batch_start], batch_size, stage1_mode::streaming_partial); - } -} - -simdjson_really_inline size_t document_stream::iterator::current_index() const - noexcept { - return stream->doc_index; -} - -simdjson_really_inline std::string_view document_stream::iterator::source() - const noexcept { - auto depth = stream->doc.iter.depth(); - auto cur_struct_index = - stream->doc.iter._root - - stream->parser->implementation->structural_indexes.get(); - - // If at root, process the first token to determine if scalar value - if (stream->doc.iter.at_root()) { - switch (stream->buf[stream->batch_start + - stream->parser->implementation - ->structural_indexes[cur_struct_index]]) { - case '{': - case '[': // Depth=1 already at start of document - break; - case '}': - case ']': - depth--; - break; - 
default: // Scalar value document - // TODO: Remove any trailing whitespaces - // This returns a string spanning from start of value to the - // beginning of the next document (excluded) - return std::string_view( - reinterpret_cast(stream->buf) + - current_index(), - stream->parser->implementation - ->structural_indexes[++cur_struct_index] - - current_index() - 1); - } - cur_struct_index++; - } - - while (cur_struct_index <= - static_cast( - stream->parser->implementation->n_structural_indexes)) { - switch (stream->buf[stream->batch_start + - stream->parser->implementation - ->structural_indexes[cur_struct_index]]) { - case '{': - case '[': - depth++; - break; - case '}': - case ']': - depth--; - break; - } - if (depth == 0) { - break; - } - cur_struct_index++; - } - - return std::string_view( - reinterpret_cast(stream->buf) + current_index(), - stream->parser->implementation->structural_indexes[cur_struct_index] - - current_index() + stream->batch_start + 1); - ; -} - -inline error_code document_stream::iterator::error() const noexcept { - return stream->error; -} - -#ifdef SIMDJSON_THREADS_ENABLED - -inline void document_stream::load_from_stage1_thread() noexcept { - worker->finish(); - // Swap to the parser that was loaded up in the thread. Make sure the parser - // has - // enough memory to swap to, as well. - std::swap(stage1_thread_parser, *parser); - error = stage1_thread_error; - if (error) { - return; - } - - // If there's anything left, start the stage 1 thread! - if (next_batch_start() < len) { - start_stage1_thread(); - } -} - -inline void document_stream::start_stage1_thread() noexcept { - // we call the thread on a lambda that will update - // this->stage1_thread_error - // there is only one thread that may write to this value - // TODO this is NOT exception-safe. 
- this->stage1_thread_error = - UNINITIALIZED; // In case something goes wrong, make sure it's an error - size_t _next_batch_start = this->next_batch_start(); - - worker->run(this, &this->stage1_thread_parser, _next_batch_start); -} - -#endif // SIMDJSON_THREADS_ENABLED - -} // namespace ondemand -} // namespace SIMDJSON_BUILTIN_IMPLEMENTATION -} // namespace simdjson - -namespace simdjson { - -simdjson_really_inline -simdjson_result:: - simdjson_result(error_code error) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream>(error) {} -simdjson_really_inline -simdjson_result:: - simdjson_result(SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream - &&value) noexcept - : implementation_simdjson_result_base< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream>( - std::forward< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_stream>( - value)) {} -} -/* end file include/simdjson/generic/ondemand/document_stream-inl.h */ -/* begin file include/simdjson/generic/ondemand/serialization-inl.h */ - - -namespace simdjson { - -inline std::string_view trim(const std::string_view str) noexcept { - // We can almost surely do better by rolling our own find_first_not_of - // function. - size_t first = str.find_first_not_of(" \t\n\r"); - // If we have the empty string (just white space), then no trimming is - // possible, and - // we return the empty string_view. 
- if (std::string_view::npos == first) { - return std::string_view(); - } - size_t last = str.find_last_not_of(" \t\n\r"); - return str.substr(first, (last - first + 1)); -} - - -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &x) noexcept { - std::string_view v; - auto error = x.raw_json().get(v); - if (error) { - return error; - } - return trim(v); -} - -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference &x) noexcept { - std::string_view v; - auto error = x.raw_json().get(v); - if (error) { - return error; - } - return trim(v); -} - -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value &x) noexcept { - /** - * If we somehow receive a value that has already been consumed, - * then the following code could be in trouble. E.g., we create - * an array as needed, but if an array was already created, then - * it could be bad. - */ - using namespace SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand; - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::json_type t; - auto error = x.type().get(t); - if (error != SUCCESS) { - return error; - } - switch (t) { - case json_type::array: { - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array array; - error = x.get_array().get(array); - if (error) { - return error; - } - return to_json_string(array); - } - case json_type::object: { - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object object; - error = x.get_object().get(object); - if (error) { - return error; - } - return to_json_string(object); - } - default: - return trim(x.raw_json_token()); - } -} - -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object &x) noexcept { - std::string_view v; - auto error = x.raw_json().get(v); - if (error) { - return error; - } - return trim(v); -} - -inline simdjson_result to_json_string( - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array &x) noexcept { - std::string_view v; - auto error = 
x.raw_json().get(v); - if (error) { - return error; - } - return trim(v); -} - -inline simdjson_result to_json_string( - simdjson_result x) { - if (x.error()) { - return x.error(); - } - return to_json_string(x.value_unsafe()); -} - -inline simdjson_result to_json_string( - simdjson_result< - SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference> x) { - if (x.error()) { - return x.error(); - } - return to_json_string(x.value_unsafe()); -} - -inline simdjson_result to_json_string( - simdjson_result x) { - if (x.error()) { - return x.error(); - } - return to_json_string(x.value_unsafe()); -} - -inline simdjson_result to_json_string( - simdjson_result x) { - if (x.error()) { - return x.error(); - } - return to_json_string(x.value_unsafe()); -} - -inline simdjson_result to_json_string( - simdjson_result x) { - if (x.error()) { - return x.error(); - } - return to_json_string(x.value_unsafe()); -} -} // namespace simdjson - -namespace simdjson { -namespace SIMDJSON_BUILTIN_IMPLEMENTATION { -namespace ondemand { - -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value x) { - std::string_view v; - auto error = simdjson::to_json_string(x).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - throw simdjson::simdjson_error(error); - } -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value> x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#else -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::value x) { - std::string_view v; - auto error = simdjson::to_json_string(x).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - return (out << error); - } -} -#endif - -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - 
std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - throw simdjson::simdjson_error(error); - } -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array> x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#else -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::array value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - return (out << error); - } -} -#endif - -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - throw simdjson::simdjson_error(error); - } -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference - &value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - throw simdjson::simdjson_error(error); - } -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document> &&x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document_reference> - &&x) { - if (x.error()) { - throw 
simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#else -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::document &value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - return (out << error); - } -} -#endif - -#if SIMDJSON_EXCEPTIONS -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - throw simdjson::simdjson_error(error); - } -} -inline std::ostream &operator<<( - std::ostream &out, - simdjson::simdjson_result< - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object> x) { - if (x.error()) { - throw simdjson::simdjson_error(x.error()); - } - return (out << x.value()); -} -#else -inline std::ostream &operator<<( - std::ostream &out, - simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand::object value) { - std::string_view v; - auto error = simdjson::to_json_string(value).get(v); - if (error == simdjson::SUCCESS) { - return (out << v); - } else { - return (out << error); - } -} -#endif -} -} -} // namespace simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand -/* end file include/simdjson/generic/ondemand/serialization-inl.h */ -/* end file include/simdjson/generic/ondemand-inl.h */ - - -namespace simdjson { -/** - * Represents the best statically linked simdjson implementation that can be - * used by the compiling - * program. - * - * Detects what options the program is compiled against, and picks the minimum - * implementation that - * will work on any computer that can run the program. For example, if you - * compile with g++ - * -march=westmere, it will pick the westmere implementation. 
The haswell - * implementation will - * still be available, and can be selected at runtime, but the builtin - * implementation (and any - * code that uses it) will use westmere. - */ -namespace builtin = SIMDJSON_BUILTIN_IMPLEMENTATION; -/** - * @copydoc simdjson::SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand - */ -namespace ondemand = SIMDJSON_BUILTIN_IMPLEMENTATION::ondemand; -/** - * Function which returns a pointer to an implementation matching the "builtin" - * implementation. - * The builtin implementation is the best statically linked simdjson - * implementation that can be used by the compiling - * program. If you compile with g++ -march=haswell, this will return the haswell - * implementation. - * It is handy to be able to check what builtin was used: - * builtin_implementation()->name(). - */ -const implementation *builtin_implementation(); -} // namespace simdjson - -#endif // SIMDJSON_BUILTIN_H -/* end file include/simdjson/builtin.h */ - -#endif // SIMDJSON_H -/* end file include/simdjson.h */ diff --git a/speechx/speechx/websocket/CMakeLists.txt b/speechx/speechx/websocket/CMakeLists.txt deleted file mode 100644 index 582a38031..000000000 --- a/speechx/speechx/websocket/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -project(websocket) - -add_library(websocket STATIC - websocket_server.cc - websocket_client.cc -) -target_link_libraries(websocket PUBLIC frontend decoder nnet) diff --git a/audio/tests/benchmark/README.md b/tests/benchmark/audio/README.md similarity index 97% rename from audio/tests/benchmark/README.md rename to tests/benchmark/audio/README.md index b9034100d..9cade74e0 100644 --- a/audio/tests/benchmark/README.md +++ b/tests/benchmark/audio/README.md @@ -15,7 +15,6 @@ Result: ========================================================================== test session starts ========================================================================== platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0 benchmark: 3.4.1 (defaults: 
timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) -rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0 collected 4 items diff --git a/audio/tests/benchmark/log_melspectrogram.py b/tests/benchmark/audio/log_melspectrogram.py similarity index 87% rename from audio/tests/benchmark/log_melspectrogram.py rename to tests/benchmark/audio/log_melspectrogram.py index 9832aed4d..c85fcecfb 100644 --- a/audio/tests/benchmark/log_melspectrogram.py +++ b/tests/benchmark/audio/log_melspectrogram.py @@ -17,15 +17,17 @@ import urllib.request import librosa import numpy as np import paddle -import paddleaudio import torch import torchaudio +import paddlespeech.audio + wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' if not os.path.isfile(os.path.basename(wav_url)): urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) -waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform, sr = paddlespeech.audio.load( + os.path.abspath(os.path.basename(wav_url))) waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) @@ -55,7 +57,7 @@ def enable_gpu_device(): paddle.set_device('gpu') -log_mel_extractor = paddleaudio.features.LogMelSpectrogram( +log_mel_extractor = paddlespeech.audio.features.LogMelSpectrogram( **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype) @@ -65,20 +67,20 @@ def log_melspectrogram(): def test_log_melspect_cpu(benchmark): enable_cpu_device() - feature_paddleaudio = benchmark(log_melspectrogram) + feature_audio = benchmark(log_melspectrogram) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + 
feature_librosa, feature_audio, decimal=3) def test_log_melspect_gpu(benchmark): enable_gpu_device() - feature_paddleaudio = benchmark(log_melspectrogram) + feature_audio = benchmark(log_melspectrogram) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=2) + feature_librosa, feature_audio, decimal=2) mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( @@ -102,11 +104,11 @@ def test_log_melspect_cpu_torchaudio(benchmark): waveform_tensor_torch = waveform_tensor_torch.to('cpu') amplitude_to_DB = amplitude_to_DB.to('cpu') - feature_paddleaudio = benchmark(log_melspectrogram_torchaudio) + feature_audio = benchmark(log_melspectrogram_torchaudio) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) def test_log_melspect_gpu_torchaudio(benchmark): diff --git a/audio/tests/benchmark/melspectrogram.py b/tests/benchmark/audio/melspectrogram.py similarity index 85% rename from audio/tests/benchmark/melspectrogram.py rename to tests/benchmark/audio/melspectrogram.py index 5fe3f2481..498158941 100644 --- a/audio/tests/benchmark/melspectrogram.py +++ b/tests/benchmark/audio/melspectrogram.py @@ -17,15 +17,17 @@ import urllib.request import librosa import numpy as np import paddle -import paddleaudio import torch import torchaudio +import paddlespeech.audio + wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' if not os.path.isfile(os.path.basename(wav_url)): urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) -waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform, sr = paddlespeech.audio.load( + 
os.path.abspath(os.path.basename(wav_url))) waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) @@ -55,7 +57,7 @@ def enable_gpu_device(): paddle.set_device('gpu') -mel_extractor = paddleaudio.features.MelSpectrogram( +mel_extractor = paddlespeech.audio.features.MelSpectrogram( **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype) @@ -65,18 +67,18 @@ def melspectrogram(): def test_melspect_cpu(benchmark): enable_cpu_device() - feature_paddleaudio = benchmark(melspectrogram) + feature_audio = benchmark(melspectrogram) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) def test_melspect_gpu(benchmark): enable_gpu_device() - feature_paddleaudio = benchmark(melspectrogram) + feature_audio = benchmark(melspectrogram) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( @@ -91,10 +93,10 @@ def test_melspect_cpu_torchaudio(benchmark): global waveform_tensor_torch, mel_extractor_torchaudio mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu') waveform_tensor_torch = waveform_tensor_torch.to('cpu') - feature_paddleaudio = benchmark(melspectrogram_torchaudio) + feature_audio = benchmark(melspectrogram_torchaudio) feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) def test_melspect_gpu_torchaudio(benchmark): diff --git a/audio/tests/benchmark/mfcc.py b/tests/benchmark/audio/mfcc.py similarity index 87% rename from audio/tests/benchmark/mfcc.py rename to tests/benchmark/audio/mfcc.py 
index c6a8c85f9..4e286de90 100644 --- a/audio/tests/benchmark/mfcc.py +++ b/tests/benchmark/audio/mfcc.py @@ -17,15 +17,17 @@ import urllib.request import librosa import numpy as np import paddle -import paddleaudio import torch import torchaudio +import paddlespeech.audio + wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' if not os.path.isfile(os.path.basename(wav_url)): urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) -waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform, sr = paddlespeech.audio.load( + os.path.abspath(os.path.basename(wav_url))) waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) @@ -64,7 +66,7 @@ def enable_gpu_device(): paddle.set_device('gpu') -mfcc_extractor = paddleaudio.features.MFCC( +mfcc_extractor = paddlespeech.audio.features.MFCC( **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype) @@ -74,18 +76,18 @@ def mfcc(): def test_mfcc_cpu(benchmark): enable_cpu_device() - feature_paddleaudio = benchmark(mfcc) + feature_audio = benchmark(mfcc) feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) def test_mfcc_gpu(benchmark): enable_gpu_device() - feature_paddleaudio = benchmark(mfcc) + feature_audio = benchmark(mfcc) feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) del mel_conf_torchaudio['sample_rate'] @@ -103,10 +105,10 @@ def test_mfcc_cpu_torchaudio(benchmark): mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu') waveform_tensor_torch = waveform_tensor_torch.to('cpu') - feature_paddleaudio = benchmark(mfcc_torchaudio) + feature_audio = benchmark(mfcc_torchaudio) feature_librosa = librosa.feature.mfcc(waveform, 
**mel_conf) np.testing.assert_array_almost_equal( - feature_librosa, feature_paddleaudio, decimal=3) + feature_librosa, feature_audio, decimal=3) def test_mfcc_gpu_torchaudio(benchmark): diff --git a/tests/test_tipc/prepare.sh b/tests/test_tipc/prepare.sh index 31dff320f..a13938017 100644 --- a/tests/test_tipc/prepare.sh +++ b/tests/test_tipc/prepare.sh @@ -24,35 +24,36 @@ trainer_list=$(func_parser_value "${lines[14]}") if [ ${MODE} = "benchmark_train" ];then curPath=$(readlink -f "$(dirname "$0")") - echo "curPath:"${curPath} + echo "curPath:"${curPath} # /PaddleSpeech/tests/test_tipc cd ${curPath}/../.. + echo "------------- install for speech " apt-get install libsndfile1 -y + pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple pip install pytest-runner -i https://pypi.tuna.tsinghua.edu.cn/simple pip install kaldiio -i https://pypi.tuna.tsinghua.edu.cn/simple pip install setuptools_scm -i https://pypi.tuna.tsinghua.edu.cn/simple pip install . -i https://pypi.tuna.tsinghua.edu.cn/simple + pip install jsonlines + pip list cd - if [ ${model_name} == "conformer" ]; then # set the URL for aishell_tiny dataset - URL=${conformer_data_URL:-"None"} - echo "URL:"${URL} - if [ ${URL} == 'None' ];then + conformer_aishell_URL=${conformer_aishell_URL:-"None"} + if [ ${conformer_aishell_URL} == 'None' ];then echo "please contact author to get the URL.\n" exit - else - wget -P ${curPath}/../../dataset/aishell/ ${URL} + else + rm -rf ${curPath}/../../dataset/aishell/aishell.py + rm -rf ${curPath}/../../dataset/aishell/data_aishell_tiny* + wget -P ${curPath}/../../dataset/aishell/ ${conformer_aishell_URL} fi - sed -i "s#^URL_ROOT_TAG#URL_ROOT = '${URL}'#g" ${curPath}/conformer/scripts/aishell_tiny.py - cp ${curPath}/conformer/scripts/aishell_tiny.py ${curPath}/../../dataset/aishell/ cd ${curPath}/../../examples/aishell/asr1 - source path.sh - # download audio data - sed -i "s#aishell.py#aishell_tiny.py#g" ./local/data.sh - sed -i "s#python3#python#g" ./local/data.sh - 
bash ./local/data.sh || exit -1 - if [ $? -ne 0 ]; then - exit 1 - fi + + #Prepare the data + sed -i "s#python3#python#g" ./local/data.sh + bash run.sh --stage 0 --stop_stage 0 # 执行第一遍的时候会偶现报错 + bash run.sh --stage 0 --stop_stage 0 + mkdir -p ${curPath}/conformer/benchmark_train/ cp -rf conf ${curPath}/conformer/benchmark_train/ cp -rf data ${curPath}/conformer/benchmark_train/ diff --git a/audio/tests/backends/soundfile/__init__.py b/tests/unit/audio/backends/__init__.py similarity index 100% rename from audio/tests/backends/soundfile/__init__.py rename to tests/unit/audio/backends/__init__.py diff --git a/audio/tests/backends/base.py b/tests/unit/audio/backends/base.py similarity index 100% rename from audio/tests/backends/base.py rename to tests/unit/audio/backends/base.py diff --git a/audio/tests/features/__init__.py b/tests/unit/audio/backends/soundfile/__init__.py similarity index 100% rename from audio/tests/features/__init__.py rename to tests/unit/audio/backends/soundfile/__init__.py diff --git a/audio/tests/backends/soundfile/test_io.py b/tests/unit/audio/backends/soundfile/test_io.py similarity index 90% rename from audio/tests/backends/soundfile/test_io.py rename to tests/unit/audio/backends/soundfile/test_io.py index 9d092902d..26276751f 100644 --- a/audio/tests/backends/soundfile/test_io.py +++ b/tests/unit/audio/backends/soundfile/test_io.py @@ -16,16 +16,16 @@ import os import unittest import numpy as np -import paddleaudio import soundfile as sf +import paddlespeech.audio from ..base import BackendTest class TestIO(BackendTest): def test_load_mono_channel(self): sf_data, sf_sr = sf.read(self.files[0]) - pa_data, pa_sr = paddleaudio.load( + pa_data, pa_sr = paddlespeech.audio.load( self.files[0], normal=False, dtype='float64') self.assertEqual(sf_data.dtype, pa_data.dtype) @@ -35,7 +35,7 @@ class TestIO(BackendTest): def test_load_multi_channels(self): sf_data, sf_sr = sf.read(self.files[1]) sf_data = sf_data.T # Channel dim first - pa_data, pa_sr = 
paddleaudio.load( + pa_data, pa_sr = paddlespeech.audio.load( self.files[1], mono=False, normal=False, dtype='float64') self.assertEqual(sf_data.dtype, pa_data.dtype) @@ -49,7 +49,7 @@ class TestIO(BackendTest): pa_tmp_file = 'pa_tmp.wav' sf.write(sf_tmp_file, waveform, sr) - paddleaudio.save(waveform, sr, pa_tmp_file) + paddlespeech.audio.save(waveform, sr, pa_tmp_file) self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) for file in [sf_tmp_file, pa_tmp_file]: @@ -62,7 +62,7 @@ class TestIO(BackendTest): pa_tmp_file = 'pa_tmp.wav' sf.write(sf_tmp_file, waveform.T, sr) - paddleaudio.save(waveform.T, sr, pa_tmp_file) + paddlespeech.audio.save(waveform.T, sr, pa_tmp_file) self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) for file in [sf_tmp_file, pa_tmp_file]: diff --git a/paddlespeech/cli/stats/__init__.py b/tests/unit/audio/features/__init__.py similarity index 84% rename from paddlespeech/cli/stats/__init__.py rename to tests/unit/audio/features/__init__.py index 9fe6c4aba..97043fd7b 100644 --- a/paddlespeech/cli/stats/__init__.py +++ b/tests/unit/audio/features/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .infer import StatsExecutor diff --git a/audio/tests/features/base.py b/tests/unit/audio/features/base.py similarity index 97% rename from audio/tests/features/base.py rename to tests/unit/audio/features/base.py index 476f6b8ee..6d59f72b5 100644 --- a/audio/tests/features/base.py +++ b/tests/unit/audio/features/base.py @@ -17,7 +17,8 @@ import urllib.request import numpy as np import paddle -from paddleaudio import load + +from paddlespeech.audio import load wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' diff --git a/audio/tests/features/test_istft.py b/tests/unit/audio/features/test_istft.py similarity index 96% rename from audio/tests/features/test_istft.py rename to tests/unit/audio/features/test_istft.py index 9cf8cdd65..f1e6e4e33 100644 --- a/audio/tests/features/test_istft.py +++ b/tests/unit/audio/features/test_istft.py @@ -15,9 +15,9 @@ import unittest import numpy as np import paddle -from paddleaudio.functional.window import get_window from .base import FeatTest +from paddlespeech.audio.functional.window import get_window from paddlespeech.s2t.transform.spectrogram import IStft from paddlespeech.s2t.transform.spectrogram import Stft diff --git a/audio/tests/features/test_kaldi.py b/tests/unit/audio/features/test_kaldi.py similarity index 87% rename from audio/tests/features/test_kaldi.py rename to tests/unit/audio/features/test_kaldi.py index 00a576f6f..2b0ece890 100644 --- a/audio/tests/features/test_kaldi.py +++ b/tests/unit/audio/features/test_kaldi.py @@ -15,10 +15,10 @@ import unittest import numpy as np import paddle -import paddleaudio import torch import torchaudio +import paddlespeech.audio from .base import FeatTest @@ -40,17 +40,17 @@ class TestKaldi(FeatTest): self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}')).pow(0.85) - p_hann_window = paddleaudio.functional.window.get_window( + p_hann_window = paddlespeech.audio.functional.window.get_window( 'hann', self.window_size, fftbins=False, 
dtype=eval(f'paddle.{self.dtype}')) - p_hamm_window = paddleaudio.functional.window.get_window( + p_hamm_window = paddlespeech.audio.functional.window.get_window( 'hamming', self.window_size, fftbins=False, dtype=eval(f'paddle.{self.dtype}')) - p_povey_window = paddleaudio.functional.window.get_window( + p_povey_window = paddlespeech.audio.functional.window.get_window( 'hann', self.window_size, fftbins=False, @@ -63,7 +63,7 @@ class TestKaldi(FeatTest): def test_fbank(self): ta_features = torchaudio.compliance.kaldi.fbank( torch.from_numpy(self.waveform.astype(self.dtype))) - pa_features = paddleaudio.compliance.kaldi.fbank( + pa_features = paddlespeech.audio.compliance.kaldi.fbank( paddle.to_tensor(self.waveform.astype(self.dtype))) np.testing.assert_array_almost_equal( ta_features, pa_features, decimal=4) @@ -71,7 +71,7 @@ class TestKaldi(FeatTest): def test_mfcc(self): ta_features = torchaudio.compliance.kaldi.mfcc( torch.from_numpy(self.waveform.astype(self.dtype))) - pa_features = paddleaudio.compliance.kaldi.mfcc( + pa_features = paddlespeech.audio.compliance.kaldi.mfcc( paddle.to_tensor(self.waveform.astype(self.dtype))) np.testing.assert_array_almost_equal( ta_features, pa_features, decimal=4) diff --git a/audio/tests/features/test_librosa.py b/tests/unit/audio/features/test_librosa.py similarity index 89% rename from audio/tests/features/test_librosa.py rename to tests/unit/audio/features/test_librosa.py index a1d3e8400..ffdec3e78 100644 --- a/audio/tests/features/test_librosa.py +++ b/tests/unit/audio/features/test_librosa.py @@ -16,10 +16,10 @@ import unittest import librosa import numpy as np import paddle -import paddleaudio -from paddleaudio.functional.window import get_window +import paddlespeech.audio from .base import FeatTest +from paddlespeech.audio.functional.window import get_window class TestLibrosa(FeatTest): @@ -117,7 +117,7 @@ class TestLibrosa(FeatTest): htk=False, norm='slaney', dtype=self.waveform.dtype, ) - feature_compliance = 
paddleaudio.compliance.librosa.compute_fbank_matrix( + feature_compliance = paddlespeech.audio.compliance.librosa.compute_fbank_matrix( sr=self.sr, n_fft=self.n_fft, n_mels=self.n_mels, @@ -127,7 +127,7 @@ class TestLibrosa(FeatTest): norm='slaney', dtype=self.waveform.dtype, ) x = paddle.to_tensor(self.waveform) - feature_functional = paddleaudio.functional.compute_fbank_matrix( + feature_functional = paddlespeech.audio.functional.compute_fbank_matrix( sr=self.sr, n_fft=self.n_fft, n_mels=self.n_mels, @@ -156,8 +156,8 @@ class TestLibrosa(FeatTest): n_mels=self.n_mels, fmin=self.fmin) - # paddleaudio.compliance.librosa: - feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + # paddlespeech.audio.compliance.librosa: + feature_compliance = paddlespeech.audio.compliance.librosa.melspectrogram( x=self.waveform, sr=self.sr, window_size=self.n_fft, @@ -166,10 +166,10 @@ class TestLibrosa(FeatTest): fmin=self.fmin, to_db=False) - # paddleaudio.features.layer + # paddlespeech.audio.features.layer x = paddle.to_tensor( self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. - feature_extractor = paddleaudio.features.MelSpectrogram( + feature_extractor = paddlespeech.audio.features.MelSpectrogram( sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, @@ -198,8 +198,8 @@ class TestLibrosa(FeatTest): fmin=self.fmin) feature_librosa = librosa.power_to_db(feature_librosa, top_db=None) - # paddleaudio.compliance.librosa: - feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + # paddlespeech.audio.compliance.librosa: + feature_compliance = paddlespeech.audio.compliance.librosa.melspectrogram( x=self.waveform, sr=self.sr, window_size=self.n_fft, @@ -207,10 +207,10 @@ class TestLibrosa(FeatTest): n_mels=self.n_mels, fmin=self.fmin) - # paddleaudio.features.layer + # paddlespeech.audio.features.layer x = paddle.to_tensor( self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. 
- feature_extractor = paddleaudio.features.LogMelSpectrogram( + feature_extractor = paddlespeech.audio.features.LogMelSpectrogram( sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, @@ -243,8 +243,8 @@ class TestLibrosa(FeatTest): n_mels=self.n_mels, fmin=self.fmin) - # paddleaudio.compliance.librosa: - feature_compliance = paddleaudio.compliance.librosa.mfcc( + # paddlespeech.audio.compliance.librosa: + feature_compliance = paddlespeech.audio.compliance.librosa.mfcc( x=self.waveform, sr=self.sr, n_mfcc=self.n_mfcc, @@ -257,10 +257,10 @@ class TestLibrosa(FeatTest): fmin=self.fmin, top_db=self.top_db) - # paddleaudio.features.layer + # paddlespeech.audio.features.layer x = paddle.to_tensor( self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. - feature_extractor = paddleaudio.features.MFCC( + feature_extractor = paddlespeech.audio.features.MFCC( sr=self.sr, n_mfcc=self.n_mfcc, n_fft=self.n_fft, diff --git a/audio/tests/features/test_log_melspectrogram.py b/tests/unit/audio/features/test_log_melspectrogram.py similarity index 90% rename from audio/tests/features/test_log_melspectrogram.py rename to tests/unit/audio/features/test_log_melspectrogram.py index 0383c2b8b..59eb73e8c 100644 --- a/audio/tests/features/test_log_melspectrogram.py +++ b/tests/unit/audio/features/test_log_melspectrogram.py @@ -15,8 +15,8 @@ import unittest import numpy as np import paddle -import paddleaudio +import paddlespeech.audio from .base import FeatTest from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram @@ -33,8 +33,7 @@ class TestLogMelSpectrogram(FeatTest): ps_res = ps_melspect(self.waveform.T).squeeze(1).T x = paddle.to_tensor(self.waveform) - # paddlespeech.s2t的特征存在幅度谱和功率谱滥用的情况 - ps_melspect = paddleaudio.features.LogMelSpectrogram( + ps_melspect = paddlespeech.audio.features.LogMelSpectrogram( self.sr, self.n_fft, self.hop_length, diff --git a/audio/tests/features/test_spectrogram.py b/tests/unit/audio/features/test_spectrogram.py similarity 
index 93% rename from audio/tests/features/test_spectrogram.py rename to tests/unit/audio/features/test_spectrogram.py index 1774fe619..7d908a7ef 100644 --- a/audio/tests/features/test_spectrogram.py +++ b/tests/unit/audio/features/test_spectrogram.py @@ -15,8 +15,8 @@ import unittest import numpy as np import paddle -import paddleaudio +import paddlespeech.audio from .base import FeatTest from paddlespeech.s2t.transform.spectrogram import Spectrogram @@ -31,7 +31,7 @@ class TestSpectrogram(FeatTest): ps_res = ps_spect(self.waveform.T).squeeze(1).T # Magnitude x = paddle.to_tensor(self.waveform) - pa_spect = paddleaudio.features.Spectrogram( + pa_spect = paddlespeech.audio.features.Spectrogram( self.n_fft, self.hop_length, power=1.0) pa_res = pa_spect(x).squeeze(0).numpy() diff --git a/audio/tests/features/test_stft.py b/tests/unit/audio/features/test_stft.py similarity index 95% rename from audio/tests/features/test_stft.py rename to tests/unit/audio/features/test_stft.py index 58792ffe2..03448ca80 100644 --- a/audio/tests/features/test_stft.py +++ b/tests/unit/audio/features/test_stft.py @@ -15,9 +15,9 @@ import unittest import numpy as np import paddle -from paddleaudio.functional.window import get_window from .base import FeatTest +from paddlespeech.audio.functional.window import get_window from paddlespeech.s2t.transform.spectrogram import Stft diff --git a/tests/unit/cli/aishell_test_prepare.py b/tests/unit/cli/aishell_test_prepare.py index 288de62a0..ed542c571 100644 --- a/tests/unit/cli/aishell_test_prepare.py +++ b/tests/unit/cli/aishell_test_prepare.py @@ -20,7 +20,6 @@ of each audio file in the data set. """ import argparse import codecs -import json import os from pathlib import Path @@ -55,6 +54,7 @@ args = parser.parse_args() def create_manifest(data_dir, manifest_path_prefix): print("Creating manifest %s ..." 
% manifest_path_prefix) json_lines = [] + reference_lines = [] transcript_path = os.path.join(data_dir, 'transcript', 'aishell_transcript_v0.8.txt') transcript_dict = {} @@ -88,6 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix): duration = float(len(audio_data) / samplerate) text = transcript_dict[audio_id] json_lines.append(audio_path) + reference_lines.append(str(total_num + 1) + "\t" + text) total_sec += duration total_text += len(text) @@ -98,8 +99,13 @@ def create_manifest(data_dir, manifest_path_prefix): for line in json_lines: fout.write(line + '\n') + with codecs.open(manifest_path + ".text", 'w', 'utf-8') as fout: + for line in reference_lines: + fout.write(line + '\n') + manifest_dir = os.path.dirname(manifest_path_prefix) + def prepare_dataset(url, md5sum, target_dir, manifest_path=None): """Download, unpack and create manifest file.""" data_dir = os.path.join(target_dir, 'data_aishell') diff --git a/tests/unit/cli/calc_rtf_by_aishell.sh b/tests/unit/cli/calc_RTF_CER_by_aishell.sh similarity index 54% rename from tests/unit/cli/calc_rtf_by_aishell.sh rename to tests/unit/cli/calc_RTF_CER_by_aishell.sh index cee79160e..a5a1a77c1 100644 --- a/tests/unit/cli/calc_rtf_by_aishell.sh +++ b/tests/unit/cli/calc_RTF_CER_by_aishell.sh @@ -3,6 +3,10 @@ source path.sh stage=-1 stop_stage=100 +model_name=conformer_online_aishell +gpus=5 +log_file=res.log +res_file=res.rsl MAIN_ROOT=../../.. . ${MAIN_ROOT}/utils/parse_options.sh || exit -1; @@ -20,9 +24,16 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then echo "Prepare Aishell failed. Terminated." 
exit 1 fi - fi + if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - cat data/manifest.test | paddlespeech asr --model conformer_online_aishell --device gpu --decode_method ctc_prefix_beam_search --rtf -v + export CUDA_VISIBLE_DEVICES=${gpus} + cat data/manifest.test | paddlespeech asr --model ${model_name} --device gpu --decode_method attention_rescoring --rtf -v &> ${log_file} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + cat ${log_file} | grep "^[0-9]" > ${res_file} + python utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test.text ${res_file} > ${res_file}.error fi diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index e1f1853f6..6879c4d64 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -22,10 +22,13 @@ paddlespeech asr --model deepspeech2online_wenetspeech --input ./zh.wav paddlespeech asr --model deepspeech2online_aishell --input ./zh.wav paddlespeech asr --model deepspeech2offline_librispeech --lang en --input ./en.wav +# Support editing num_decoding_left_chunks +paddlespeech asr --model conformer_online_wenetspeech --num_decoding_left_chunks 3 --input ./zh.wav + # long audio restriction { wget -c https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/test_long_audio_01.wav -paddlespeech asr --input test_long_audio_01.wav +paddlespeech asr --model deepspeech2online_wenetspeech --input test_long_audio_01.wav -y if [ $? 
-ne 255 ]; then echo -e "\e[1;31mTime restriction not passed\e[0m" exit 1 @@ -54,7 +57,7 @@ paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input # Speech Translation (only support linux) paddlespeech st --input ./en.wav -# Speaker Verification +# Speaker Verification wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav paddlespeech vector --task spk --input 85236145389.wav @@ -65,7 +68,7 @@ echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job paddlespeech vector --task spk --input vec.job echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector --task spk -rm 85236145389.wav +rm 85236145389.wav rm vec.job # shell pipeline diff --git a/tests/unit/cli/utils b/tests/unit/cli/utils new file mode 120000 index 000000000..973afe674 --- /dev/null +++ b/tests/unit/cli/utils @@ -0,0 +1 @@ +../../../utils \ No newline at end of file diff --git a/tests/unit/server/online/tts/check_server/test.sh b/tests/unit/server/online/tts/check_server/test.sh index 766aea850..c62c54c76 100644 --- a/tests/unit/server/online/tts/check_server/test.sh +++ b/tests/unit/server/online/tts/check_server/test.sh @@ -28,7 +28,7 @@ StartService(){ ClientTest_http(){ for ((i=1; i<=3;i++)) do - paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" + paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --port $port ((http_test_times+=1)) done } @@ -36,7 +36,7 @@ ClientTest_http(){ ClientTest_ws(){ for ((i=1; i<=3;i++)) do - paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --protocol websocket + paddlespeech_client tts_online --input "您好,欢迎使用百度飞桨深度学习框架。" --protocol websocket --port $port ((ws_test_times+=1)) done } @@ -54,7 +54,7 @@ GetTestResult_http() { GetTestResult_ws() { # Determine if the test was successful - ws_response_success_time=$(cat $log/server.log.wf | grep "Complete the transmission of audio streams" -c) + ws_response_success_time=$(cat $log/server.log.wf | grep "Complete the synthesis 
of the audio streams" -c) if (( $ws_response_success_time == $ws_test_times )) ; then echo "Testing successfully. $info" | tee -a $log/test_result.log else @@ -313,4 +313,4 @@ cat $log/test_result.log # Restoring conf is the same as demos/speech_server cp ./tts_online_application.yaml ./conf/application.yaml -rf -sleep 2s \ No newline at end of file +sleep 2s diff --git a/third_party/README.md b/third_party/README.md index c73df5427..843d0d3b2 100644 --- a/third_party/README.md +++ b/third_party/README.md @@ -1,27 +1,27 @@ * [python_kaldi_features](https://github.com/ZitengWang/python_kaldi_features) commit: fc1bd6240c2008412ab64dc25045cd872f5e126c ref: https://zhuanlan.zhihu.com/p/55371926 -licence: MIT +license: MIT * [python-pinyin](https://github.com/mozillazg/python-pinyin.git) commit: 55e524aa1b7b8eec3d15c5306043c6cdd5938b03 -licence: MIT +license: MIT * [zhon](https://github.com/tsroten/zhon) commit: 09bf543696277f71de502506984661a60d24494c -licence: MIT +license: MIT * [pymmseg-cpp](https://github.com/pluskid/pymmseg-cpp.git) commit: b76465045717fbb4f118c4fbdd24ce93bab10a6d -licence: MIT +license: MIT * [chinese_text_normalization](https://github.com/speechio/chinese_text_normalization.git) commit: 9e92c7bf2d6b5a7974305406d8e240045beac51c -licence: MIT +license: MIT * [phkit](https://github.com/KuangDD/phkit.git) commit: b2100293c1e36da531d7f30bd52c9b955a649522 -licence: None +license: None * [nnAudio](https://github.com/KinWaiCheuk/nnAudio.git) -licence: MIT +license: MIT diff --git a/third_party/ctc_decoders/LICENSE b/third_party/ctc_decoders/LICENSE index eeef74b30..ad947f8d7 100644 --- a/third_party/ctc_decoders/LICENSE +++ b/third_party/ctc_decoders/LICENSE @@ -5,4 +5,4 @@ score.h and score.cpp is under the LGPL license. The two files include the header files from KenLM project. For the rest: -The default licence of paddlespeech-ctcdecoders is Apache License 2.0. +The default license of paddlespeech-ctcdecoders is Apache License 2.0. 
diff --git a/utils/README.md b/utils/README.md index 163be850f..db2064efa 100644 --- a/utils/README.md +++ b/utils/README.md @@ -1,4 +1,4 @@ # Utils * [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils) -* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils) +* [espnet utils](https://github.com/espnet/espnet/tree/master/utils) diff --git a/utils/compute-wer.py b/utils/compute-wer.py index 978a80c9f..98bb24a7e 100755 --- a/utils/compute-wer.py +++ b/utils/compute-wer.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# CopyRight WeNet Apache-2.0 License +# Copyright 2021 Mobvoi Inc. All Rights Reserved. import codecs import re import sys diff --git a/speechx/examples/ngram/zh/local/text_to_lexicon.py b/utils/text_to_lexicon.py similarity index 100% rename from speechx/examples/ngram/zh/local/text_to_lexicon.py rename to utils/text_to_lexicon.py diff --git a/utils/zh_tn.py b/utils/zh_tn.py index 73bb8af22..6fee626bd 100755 --- a/utils/zh_tn.py +++ b/utils/zh_tn.py @@ -747,7 +747,7 @@ def num2chn(number_string, previous_symbol, (CNU, type(None))): if next_symbol.power != 1 and ( (previous_symbol is None) or - (previous_symbol.power != 1)): + (previous_symbol.power != 1)): # noqa: E129 result_symbols[i] = liang # if big is True, '两' will not be used and `alt_two` has no impact on output