commit a9027f18d0 (pull/3197/head) by YangZhou

@ -0,0 +1,77 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Racial or political allusions
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at paddlespeech@baidu.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

@ -0,0 +1,30 @@
# 💡 paddlespeech code submission guide
### Discussed in https://github.com/PaddlePaddle/PaddleSpeech/discussions/1326
<div type='discussions-op-text'>
<sup>Originally posted by **yt605155624** January 12, 2022</sup>
1. After you finish your code, run our pre-commit hooks to check the formatting. Only fix the formatting of the files you actually modified; other files may get reformatted too, simply don't `git add` them.
```
pip install pre-commit
pre-commit run --files <the files you modified>
```
2. Add the proper tag to your commit message to skip unnecessary CI jobs:
- for ASR-related code
```text
git commit -m "xxxxxx, test=asr"
```
- for TTS-related code
```text
git commit -m "xxxxxx, test=tts"
```
- for documentation-only changes
```text
git commit -m "xxxxxx, test=doc"
```
Notes:
1. Even with the skip tag, your PR still has to wait in the CI queue; the skip only takes effect once a job is scheduled, so don't worry if jobs outside your area show as pending 🤣
2. Adding `test=xxx` only during `git commit --amend` may not take effect.
3. If a PR contains multiple commits, add `test=xxx` to every commit message, because each commit triggers CI (see the sketch below).
4. Remove any paddlespeech already installed in your Python environment; otherwise it may affect the import order of `import paddlespeech`.</div>
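For example, a hypothetical PR with two TTS-related commits would tag each commit message separately:
```text
git commit -m "add vocoder config, test=tts"
git commit -m "fix vocoder config typo, test=tts"
```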

@ -3,7 +3,6 @@ name: "\U0001F41B TTS Bug Report"
about: Create a report to help us improve
title: "[TTS]XXXX"
labels: Bug, T2S
assignees: yt605155624
---

.github/stale.yml

@ -6,7 +6,8 @@ daysUntilClose: 30
exemptLabels:
- Roadmap
- Bug
- New Feature
- feature request
- Tips
# Label to use when marking an issue as stale
staleLabel: Stale
# Comment to post when marking an issue as stale. Set to `false` to disable
@ -17,4 +18,4 @@ markComment: >
unmarkComment: false
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: >
This issue is closed. Please re-open if needed.

.gitignore

@ -15,6 +15,7 @@
*.egg-info
build
*output/
.history
audio/dist/
audio/fc_patch/

@ -19,7 +19,7 @@ import subprocess
import platform
COPYRIGHT = '''
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -128,4 +128,4 @@ def main(argv=None):
if __name__ == '__main__':
exit(main())

@ -97,26 +97,47 @@
</thead>
<tbody>
<tr>
<td >Life was like a box of chocolates, you never know what you're gonna get.</td>
<td>Life was like a box of chocolates, you never know what you're gonna get.</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/tacotron2_ljspeech_waveflow_samples_0.2/sentence_1.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td >早上好今天是2020/10/29最低温度是-3°C。</td>
<td>早上好今天是2020/10/29最低温度是-3°C。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/parakeet_espnet_fs2_pwg_demo/tn_g2p/parakeet/001.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td >季姬寂,集鸡,鸡即棘鸡。棘鸡饥叽,季姬及箕稷济鸡。鸡既济,跻姬笈,季姬忌,急咭鸡,鸡急,继圾几,季姬急,即籍箕击鸡,箕疾击几伎,伎即齑,鸡叽集几基,季姬急极屐击鸡,鸡既殛,季姬激,即记《季姬击鸡记》。</td>
<td>季姬寂,集鸡,鸡即棘鸡。棘鸡饥叽,季姬及箕稷济鸡。鸡既济,跻姬笈,季姬忌,急咭鸡,鸡急,继圾几,季姬急,即籍箕击鸡,箕疾击几伎,伎即齑,鸡叽集几基,季姬急极屐击鸡,鸡既殛,季姬激,即记《季姬击鸡记》。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/jijiji.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>大家好,我是 parrot 虚拟老师我们来读一首诗我与春风皆过客I and the spring breeze are passing by你携秋水揽星河you take the autumn water to take the galaxy。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>宜家唔系事必要你讲,但系你所讲嘅说话将会变成呈堂证供。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/chengtangzhenggong.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>各个国家有各个国家嘅国歌</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/gegege.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
</tbody>
</table>
@ -157,16 +178,24 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
- 🎉 2022.12.02: Add [end-to-end Prosody Prediction pipeline](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3_rhy) (including using prosody labels in Acoustic Model).
- 🎉 2022.11.30: Add [TTS Android Demo](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/TTSAndroid).
- 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); quality is being continuously improved.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
- 🎉 2023.03.07: Add [TTS ARM Linux C++ Demo (with C++ Chinese Text Frontend)](./demos/TTSArmLinux).
- 🔥 2023.03.03: Add Voice Conversion [StarGANv2-VC synthesis pipeline](./examples/vctk/vc3).
- 🎉 2023.02.16: Add [Cantonese TTS](./examples/canton/tts3).
- 🔥 2023.01.10: Add [code-switch ASR CLI and Demos](./demos/speech_recognition).
- 👑 2023.01.06: Add [code-switch ASR tal_cs recipe](./examples/tal_cs/asr1/).
- 🎉 2022.12.02: Add [end-to-end Prosody Prediction pipeline](./examples/csmsc/tts3_rhy) (including using prosody labels in Acoustic Model).
- 🎉 2022.11.30: Add [TTS Android Demo](./demos/TTSAndroid).
- 🤗 2022.11.28: PP-TTS and PP-ASR demos are available in [AIStudio](https://aistudio.baidu.com/aistudio/modelsoverview) and [official website
of paddlepaddle](https://www.paddlepaddle.org.cn/models).
- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multilingual recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), supporting ASR and feature extraction.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), supporting ASR and feature extraction.
- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech).
- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
- 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR-en](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and [ERNIE-SAT](https://arxiv.org/abs/2211.03545) in [PaddleSpeech Web Demo](./demos/speech_web).
@ -180,16 +209,16 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🎉 2022.06.22: All TTS models support ONNX format.
- 🍀 2022.06.17: Add [PaddleSpeech Web Demo](./demos/speech_web).
- 👑 2022.05.13: Release [PP-ASR](./docs/source/asr/PPASR.md)、[PP-TTS](./docs/source/tts/PPTTS.md)、[PP-VPR](docs/source/vpr/PPVPR.md).
- 👏🏻 2022.05.06: `PaddleSpeech Streaming Server` is available for `Streaming ASR` with `Punctuation Restoration` and `Token Timestamp` and `Text-to-Speech`.
- 👏🏻 2022.05.06: `PaddleSpeech Server` is available for `Audio Classification`, `Automatic Speech Recognition` and `Text-to-Speech`, `Speaker Verification` and `Punctuation Restoration`.
- 👏🏻 2022.03.28: `PaddleSpeech CLI` is available for `Speaker Verification`.
- 👏🏻 2021.12.10: `PaddleSpeech CLI` is available for `Audio Classification`, `Automatic Speech Recognition`, `Speech Translation (English to Chinese)` and `Text-to-Speech`.
### Community
- Scan the QR code below with your WeChat to join the official technical exchange group, get the bonus (more than 20 GB of learning materials such as papers, code, and videos), and receive the live links of the lessons. We look forward to your participation.
<div align="center">
<img src="https://user-images.githubusercontent.com/30135920/196351517-19dece6b-d6ea-448e-a341-d6bfe5712ec1.jpg" width = "200" />
<img src="https://user-images.githubusercontent.com/30135920/212860467-9e943cc3-8be8-49a4-97fd-7c94aad8e979.jpg" width = "200" />
</div>
## Installation
@ -550,14 +579,14 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
</thead>
<tbody>
<tr>
<td> Text Frontend </td>
<td colspan="2"> &emsp; </td>
<td>
<a href = "./examples/other/tn">tn</a> / <a href = "./examples/other/g2p">g2p</a>
</td>
</tr>
<tr>
<td rowspan="5">Acoustic Model</td>
<td rowspan="6">Acoustic Model</td>
<td>Tacotron2</td>
<td>LJSpeech / CSMSC</td>
<td>
@ -592,6 +621,13 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
</td>
</tr>
<tr>
<td>DiffSinger</td>
<td>Opencpop</td>
<td>
<a href = "./examples/opencpop/svs1">DiffSinger-opencpop</a>
</td>
</tr>
<tr>
<td rowspan="6">Vocoder</td>
<td >WaveFlow</td>
@ -602,9 +638,9 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
</tr>
<tr>
<td >Parallel WaveGAN</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3 / Opencpop</td>
<td>
<a href = "./examples/ljspeech/voc1">PWGAN-ljspeech</a> / <a href = "./examples/vctk/voc1">PWGAN-vctk</a> / <a href = "./examples/csmsc/voc1">PWGAN-csmsc</a> / <a href = "./examples/aishell3/voc1">PWGAN-aishell3</a>
<a href = "./examples/ljspeech/voc1">PWGAN-ljspeech</a> / <a href = "./examples/vctk/voc1">PWGAN-vctk</a> / <a href = "./examples/csmsc/voc1">PWGAN-csmsc</a> / <a href = "./examples/aishell3/voc1">PWGAN-aishell3</a> / <a href = "./examples/opencpop/voc1">PWGAN-opencpop</a>
</td>
</tr>
<tr>
@ -623,9 +659,9 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
</tr>
<tr>
<td>HiFiGAN</td>
<td>LJSpeech / VCTK / CSMSC / AISHELL-3</td>
<td>LJSpeech / VCTK / CSMSC / AISHELL-3 / Opencpop</td>
<td>
<a href = "./examples/ljspeech/voc5">HiFiGAN-ljspeech</a> / <a href = "./examples/vctk/voc5">HiFiGAN-vctk</a> / <a href = "./examples/csmsc/voc5">HiFiGAN-csmsc</a> / <a href = "./examples/aishell3/voc5">HiFiGAN-aishell3</a>
<a href = "./examples/ljspeech/voc5">HiFiGAN-ljspeech</a> / <a href = "./examples/vctk/voc5">HiFiGAN-vctk</a> / <a href = "./examples/csmsc/voc5">HiFiGAN-csmsc</a> / <a href = "./examples/aishell3/voc5">HiFiGAN-aishell3</a> / <a href = "./examples/opencpop/voc5">HiFiGAN-opencpop</a>
</td>
</tr>
<tr>
@ -985,10 +1021,16 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
- Many thanks to [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) for developing a rasa chatbot, which is able to speak and listen thanks to PaddleSpeech.
- Many thanks to [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) for the C++ inference implementation of PaddleSpeech ASR.
- Many thanks to [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) for the real-time voice typing tool implementation of PaddleSpeech ASR streaming services.
- Many thanks to [EscaticZheng](https://github.com/EscaticZheng)/[ps3.9wheel-install](https://github.com/EscaticZheng/ps3.9wheel-install) for the Python 3.9 prebuilt wheel for PaddleSpeech installation on Windows without Visual Studio.
Besides, PaddleSpeech depends on a lot of open source repositories. See [references](./docs/source/reference.md) for more information.
- Many thanks to [chinobing](https://github.com/chinobing)/[FastAPI-PaddleSpeech-Audio-To-Text](https://github.com/chinobing/FastAPI-PaddleSpeech-Audio-To-Text) for converting audio to text based on FastAPI and PaddleSpeech.
- Many thanks to [MistEO](https://github.com/MistEO)/[Pallas-Bot](https://github.com/MistEO/Pallas-Bot) for QQ bot based on PaddleSpeech TTS.
<a name="License"></a>
## License
PaddleSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Stargazers over time
[![Stargazers over time](https://starchart.cc/PaddlePaddle/PaddleSpeech.svg)](https://starchart.cc/PaddlePaddle/PaddleSpeech)

@ -122,6 +122,27 @@
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>大家好,我是 parrot 虚拟老师我们来读一首诗我与春风皆过客I and the spring breeze are passing by你携秋水揽星河you take the autumn water to take the galaxy。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>宜家唔系事必要你讲,但系你所讲嘅说话将会变成呈堂证供。</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/chengtangzhenggong.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
<tr>
<td>各个国家有各个国家嘅国歌</td>
<td align = "center">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/gegege.wav" rel="nofollow">
<img align="center" src="./docs/images/audio_icon.png" width="200" style="max-width: 100%;"></a><br>
</td>
</tr>
</tbody>
</table>
@ -161,18 +182,24 @@
- 🔬 *Mainstream models and datasets*: the toolkit implements modules covering every stage of the speech task pipeline and uses mainstream datasets such as LibriSpeech, LJSpeech, AIShell, and CSMSC; see the [model list](#model-list) for details.
- 🧩 *Cascaded model applications*: as an extension of traditional speech tasks, we combine them with fields such as natural language processing and computer vision to deliver industrial-grade applications closer to real-world needs.
### Recent Update
- 🎉 2022.12.02: Add [end-to-end prosody prediction pipeline](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3_rhy) (including using prosody labels in the acoustic model).
- 🎉 2022.11.30: Add [TTS Android deployment demo](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/TTSAndroid).
- 👑 2023.04.06: Add [subtitle file (.srt format) generation](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples based on the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); quality is being continuously improved.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
- 🎉 2023.03.07: Add [TTS ARM Linux C++ deployment demo (with C++ Chinese text frontend)](./demos/TTSArmLinux).
- 🔥 2023.03.03: Add voice conversion model [StarGANv2-VC synthesis pipeline](./examples/vctk/vc3).
- 🎉 2023.02.16: Add [Cantonese TTS](./examples/canton/tts3).
- 🔥 2023.01.10: Add [code-switch (Chinese-English) ASR CLI and demos](./demos/speech_recognition).
- 👑 2023.01.06: Add [code-switch ASR tal_cs training and inference recipe](./examples/tal_cs/asr1/).
- 🎉 2022.12.02: Add [end-to-end prosody prediction pipeline](./examples/csmsc/tts3_rhy) (including using prosody labels in the acoustic model).
- 🎉 2022.11.30: Add [TTS Android deployment demo](./demos/TTSAndroid).
- 🤗 2022.11.28: PP-TTS and PP-ASR demos are available on [AIStudio](https://aistudio.baidu.com/aistudio/modelsoverview) and the [official PaddlePaddle website](https://www.paddlepaddle.org.cn/models)!
- 👑 2022.11.18: Add [Whisper CLI and demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multilingual recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), supporting ASR and feature extraction.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and demos](./demos/speech_ssl), supporting ASR and feature extraction.
- 🎉 2022.11.17: Add a [high-quality male voice](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660) for TTS.
- 🔥 2022.11.07: Add [U2/U2++ high-performance streaming ASR C++ deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).
- 🔥 2022.11.07: Add [U2/U2++ high-performance streaming ASR C++ deployment](./speechx/examples/u2pp_ ol/wenetspeech).
- 👑 2022.11.01: Add an [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) module to [Chinese-English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [prosody prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
- 🔥 2022.10.26: Add [prosody prediction](./develop/examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) support to the TTS Chinese text frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR-en](./examples/librispeech/asr3), wav2vec2.0 fine-tuned for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and [ERNIE-SAT](https://arxiv.org/abs/2211.03545) to the [PaddleSpeech web demo](./demos/speech_web).
@ -200,7 +227,7 @@
Scan the QR code below with WeChat to follow the official account; click "sign up" and fill in the questionnaire to join the official exchange group, get more efficient answers to your questions, and communicate with developers from all industries. We look forward to your joining.
<div align="center">
<img src="https://user-images.githubusercontent.com/30135920/196351517-19dece6b-d6ea-448e-a341-d6bfe5712ec1.jpg" width = "200" />
<img src="https://user-images.githubusercontent.com/30135920/212860467-9e943cc3-8be8-49a4-97fd-7c94aad8e979.jpg" width = "200" />
</div>
<a name="安装"></a>
@ -551,43 +578,50 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
<td>
<a href = "./examples/other/tn">tn</a> / <a href = "./examples/other/g2p">g2p</a>
</td>
</tr>
<tr>
<td rowspan="5">Acoustic Model</td>
<td rowspan="6">Acoustic Model</td>
<td>Tacotron2</td>
<td>LJSpeech / CSMSC</td>
<td>
<a href = "./examples/ljspeech/tts0">tacotron2-ljspeech</a> / <a href = "./examples/csmsc/tts0">tacotron2-csmsc</a>
</td>
</tr>
<tr>
<td>Transformer TTS</td>
<td>LJSpeech</td>
<td>
<a href = "./examples/ljspeech/tts1">transformer-ljspeech</a>
</td>
</tr>
<tr>
<td>SpeedySpeech</td>
<td>CSMSC</td>
<td >
<a href = "./examples/csmsc/tts2">speedyspeech-csmsc</a>
</td>
</tr>
<tr>
<td>FastSpeech2</td>
<td>LJSpeech / VCTK / CSMSC / AISHELL-3 / ZH_EN / finetune</td>
<td>
<a href = "./examples/ljspeech/tts3">fastspeech2-ljspeech</a> / <a href = "./examples/vctk/tts3">fastspeech2-vctk</a> / <a href = "./examples/csmsc/tts3">fastspeech2-csmsc</a> / <a href = "./examples/aishell3/tts3">fastspeech2-aishell3</a> / <a href = "./examples/zh_en_tts/tts3">fastspeech2-zh_en</a> / <a href = "./examples/other/tts_finetune/tts3">fastspeech2-finetune</a>
</td>
</tr>
<tr>
<td><a href = "https://arxiv.org/abs/2211.03545">ERNIE-SAT</a></td>
<td>VCTK / AISHELL-3 / ZH_EN</td>
<td>
<a href = "./examples/vctk/ernie_sat">ERNIE-SAT-vctk</a> / <a href = "./examples/aishell3/ernie_sat">ERNIE-SAT-aishell3</a> / <a href = "./examples/aishell3_vctk/ernie_sat">ERNIE-SAT-zh_en</a>
</td>
</tr>
<tr>
<td>DiffSinger</td>
<td>Opencpop</td>
<td>
<a href = "./examples/opencpop/svs1">DiffSinger-opencpop</a>
</td>
</tr>
<tr>
<td rowspan="6">Vocoder</td>
<td >WaveFlow</td>
@ -598,9 +632,9 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
</tr>
<tr>
<td >Parallel WaveGAN</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3 / Opencpop</td>
<td>
<a href = "./examples/ljspeech/voc1">PWGAN-ljspeech</a> / <a href = "./examples/vctk/voc1">PWGAN-vctk</a> / <a href = "./examples/csmsc/voc1">PWGAN-csmsc</a> / <a href = "./examples/aishell3/voc1">PWGAN-aishell3</a>
<a href = "./examples/ljspeech/voc1">PWGAN-ljspeech</a> / <a href = "./examples/vctk/voc1">PWGAN-vctk</a> / <a href = "./examples/csmsc/voc1">PWGAN-csmsc</a> / <a href = "./examples/aishell3/voc1">PWGAN-aishell3</a> / <a href = "./examples/opencpop/voc1">PWGAN-opencpop</a>
</td>
</tr>
<tr>
@ -619,9 +653,9 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
</tr>
<tr>
<td >HiFiGAN</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3</td>
<td >LJSpeech / VCTK / CSMSC / AISHELL-3 / Opencpop</td>
<td>
<a href = "./examples/ljspeech/voc5">HiFiGAN-ljspeech</a> / <a href = "./examples/vctk/voc5">HiFiGAN-vctk</a> / <a href = "./examples/csmsc/voc5">HiFiGAN-csmsc</a> / <a href = "./examples/aishell3/voc5">HiFiGAN-aishell3</a>
<a href = "./examples/ljspeech/voc5">HiFiGAN-ljspeech</a> / <a href = "./examples/vctk/voc5">HiFiGAN-vctk</a> / <a href = "./examples/csmsc/voc5">HiFiGAN-csmsc</a> / <a href = "./examples/aishell3/voc5">HiFiGAN-aishell3</a> / <a href = "./examples/opencpop/voc5">HiFiGAN-opencpop</a>
</td>
</tr>
<tr>
@ -678,6 +712,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
</tbody>
</table>
<a name="声音分类模型"></a>
**Sound Classification**
@ -986,13 +1021,19 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
- Many thanks to [awmmmm](https://github.com/awmmmm) for providing the fastspeech2 aishell3 conformer pretrained model.
- Many thanks to [phecda-xu](https://github.com/phecda-xu)/[PaddleDubbing](https://github.com/phecda-xu/PaddleDubbing) for building a dubbing tool with a GUI based on PaddleSpeech TTS models.
- Many thanks to [jerryuhoo](https://github.com/jerryuhoo)/[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk) for the TTS GUI based on PaddleSpeech and the related code for building datasets with ASR.
- Many thanks to [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) for the chatbot that can listen and speak, built on PaddleSpeech ASR and TTS.
- Many thanks to [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) for the C++ inference implementation of PaddleSpeech ASR.
- Many thanks to [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) for the real-time voice typing tool built on PaddleSpeech streaming ASR services.
- Many thanks to [EscaticZheng](https://github.com/EscaticZheng)/[ps3.9wheel-install](https://github.com/EscaticZheng/ps3.9wheel-install) for the Python 3.9 prebuilt dependency package for installing PaddleSpeech on Windows without Visual Studio.
- Many thanks to [chinobing](https://github.com/chinobing)/[FastAPI-PaddleSpeech-Audio-To-Text](https://github.com/chinobing/FastAPI-PaddleSpeech-Audio-To-Text) for the FastAPI-based PaddleSpeech speech-to-text service, with file upload, splitting, conversion progress display, background task updates, and CSV output.
- Many thanks to [MistEO](https://github.com/MistEO)/[Pallas-Bot](https://github.com/MistEO/Pallas-Bot) for the QQ bot project based on PaddleSpeech TTS.
In addition, PaddleSpeech depends on many open source repositories. See [references](./docs/source/reference.md) for more information.
## License
PaddleSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Stargazers over time
[![Stargazers over time](https://starchart.cc/PaddlePaddle/PaddleSpeech.svg)](https://starchart.cc/PaddlePaddle/PaddleSpeech)

@ -41,24 +41,18 @@ option(BUILD_PADDLEAUDIO_PYTHON_EXTENSION "Build Python extension" ON)
# cmake
set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${PROJECT_SOURCE_DIR}/cmake;${PROJECT_SOURCE_DIR}/cmake/external")
if (NOT MSVC)
find_package(GFortranLibs REQUIRED)
include(FortranCInterface)
include(FindGFortranLibs REQUIRED)
endif()
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
set(THIRD_PARTY_PATH ${fc_patch})
include(openblas)
set(PYBIND11_PYTHON_VERSION ${PY_VERSION})
include(cmake/pybind.cmake)
include_directories(${PYTHON_INCLUDE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/paddleaudio/third_party/)
# packages
find_package(Python3 COMPONENTS Interpreter Development)

@ -2,33 +2,22 @@
Installation: pip install paddleaudio
Currently supported platforms: Linux
Currently supported platforms: Linux, Mac, Windows
## Environment
## Build wheel
cmd: python setup.py bdist_wheel
Linux test build whl environment:
* docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2`
* os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0
* gcc/g++ - 8.2.0
* cmake - 3.18.0 (need install)
* [How to Install Docker](https://docs.docker.com/engine/install/)
* [A Docker Tutorial for Beginners](https://docker-curriculum.com/)
1. First, launch the docker container.
```
docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.2.2 /bin/bash
```
2. python setup.py bdist_wheel
Mac test build whl environment:
* os
* gcc/g++/gfortran 12.2.0
* gcc/g++ 12.2.0
* cpu Intel Xeon E5 x86_64
Windows:
python setup.py bdist_wheel
Note: the paddleaudio C++ extension lib (sox io, kaldi native fbank) is not supported on Windows.
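Putting the Linux steps together, a typical wheel build inside the container might look like this (the `/workspace/audio` path assumes the repository was mounted at `/workspace` as in the `docker run` command above):
```
docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.2.2 /bin/bash
# inside the container
cd /workspace/audio
python setup.py bdist_wheel
ls dist/  # the built paddleaudio wheel
```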

@ -1,19 +1,3 @@
add_subdirectory(third_party)
add_subdirectory(src)
if (APPLE)
file(COPY ${GFORTRAN_LIBRARIES_DIR}/libgcc_s.1.1.dylib
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
endif(APPLE)
if (UNIX AND NOT APPLE)
file(COPY ${GFORTRAN_LIBRARIES_DIR}/libgfortran.so.5
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib FOLLOW_SYMLINK_CHAIN)
file(COPY ${GFORTRAN_LIBRARIES_DIR}/libquadmath.so.0
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib FOLLOW_SYMLINK_CHAIN)
file(COPY ${GFORTRAN_LIBRARIES_DIR}/libgcc_s.so.1
DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib FOLLOW_SYMLINK_CHAIN)
endif()

@ -67,8 +67,11 @@ def deprecated(direction: str, version: Optional[str]=None):
def is_kaldi_available():
return is_module_available("paddleaudio._paddleaudio")
try:
from paddleaudio import _paddleaudio
return True
except Exception:
return False
def requires_kaldi():
if is_kaldi_available():
@ -128,9 +131,11 @@ def requires_soundfile():
def is_sox_available():
if platform.system() == "Windows":  # sox is not supported on Windows
try:
from paddleaudio import _paddleaudio
return True
except Exception:
return False
return is_module_available("paddleaudio._paddleaudio")
def requires_sox():
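A minimal standalone sketch of the try-import availability pattern used above, with hypothetical helper names (the real helpers live in `paddleaudio._internal.module_utils`):
```python
from functools import wraps

def is_backend_available() -> bool:
    # mirror the pattern above: importing the extension proves it is usable
    try:
        from paddleaudio import _paddleaudio  # noqa: F401
        return True
    except Exception:
        return False

def requires_backend(func):
    # hypothetical decorator in the spirit of requires_kaldi()/requires_sox()
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not is_backend_available():
            raise RuntimeError(f"{func.__name__} requires the paddleaudio C++ extension")
        return func(*args, **kwargs)
    return wrapper
```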

@ -191,7 +191,7 @@ def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
if sr <= 0:
raise ParameterError(
f'Sample rate should be larger than 0, recieved sr = {sr}')
f'Sample rate should be larger than 0, received sr = {sr}')
if y.dtype not in ['int16', 'int8']:
warnings.warn(

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .kaldi import fbank
from .kaldi import pitch
#from .kaldi import pitch

@ -16,7 +16,6 @@ from paddleaudio._internal import module_utils
__all__ = [
'fbank',
'pitch',
]
@ -33,8 +32,6 @@ def fbank(
round_to_power_of_two: bool=True,
blackman_coeff: float=0.42,
snip_edges: bool=True,
allow_downsample: bool=False,
allow_upsample: bool=False,
max_feature_vectors: int=-1,
num_bins: int=23,
low_freq: float=20,
@ -62,8 +59,6 @@ def fbank(
frame_opts.round_to_power_of_two = round_to_power_of_two
frame_opts.blackman_coeff = blackman_coeff
frame_opts.snip_edges = snip_edges
frame_opts.allow_downsample = allow_downsample
frame_opts.allow_upsample = allow_upsample
frame_opts.max_feature_vectors = max_feature_vectors
mel_opts.num_bins = num_bins
@ -85,48 +80,48 @@ def fbank(
return feat
@module_utils.requires_kaldi()
def pitch(wav,
samp_freq: int=16000,
frame_shift_ms: float=10.0,
frame_length_ms: float=25.0,
preemph_coeff: float=0.0,
min_f0: int=50,
max_f0: int=400,
soft_min_f0: float=10.0,
penalty_factor: float=0.1,
lowpass_cutoff: int=1000,
resample_freq: int=4000,
delta_pitch: float=0.005,
nccf_ballast: int=7000,
lowpass_filter_width: int=1,
upsample_filter_width: int=5,
max_frames_latency: int=0,
frames_per_chunk: int=0,
simulate_first_pass_online: bool=False,
recompute_frame: int=500,
nccf_ballast_online: bool=False,
snip_edges: bool=True):
pitch_opts = paddleaudio._paddleaudio.PitchExtractionOptions()
pitch_opts.samp_freq = samp_freq
pitch_opts.frame_shift_ms = frame_shift_ms
pitch_opts.frame_length_ms = frame_length_ms
pitch_opts.preemph_coeff = preemph_coeff
pitch_opts.min_f0 = min_f0
pitch_opts.max_f0 = max_f0
pitch_opts.soft_min_f0 = soft_min_f0
pitch_opts.penalty_factor = penalty_factor
pitch_opts.lowpass_cutoff = lowpass_cutoff
pitch_opts.resample_freq = resample_freq
pitch_opts.delta_pitch = delta_pitch
pitch_opts.nccf_ballast = nccf_ballast
pitch_opts.lowpass_filter_width = lowpass_filter_width
pitch_opts.upsample_filter_width = upsample_filter_width
pitch_opts.max_frames_latency = max_frames_latency
pitch_opts.frames_per_chunk = frames_per_chunk
pitch_opts.simulate_first_pass_online = simulate_first_pass_online
pitch_opts.recompute_frame = recompute_frame
pitch_opts.nccf_ballast_online = nccf_ballast_online
pitch_opts.snip_edges = snip_edges
pitch = paddleaudio._paddleaudio.ComputeKaldiPitch(pitch_opts, wav)
return pitch
#@module_utils.requires_kaldi()
#def pitch(wav,
#samp_freq: int=16000,
#frame_shift_ms: float=10.0,
#frame_length_ms: float=25.0,
#preemph_coeff: float=0.0,
#min_f0: int=50,
#max_f0: int=400,
#soft_min_f0: float=10.0,
#penalty_factor: float=0.1,
#lowpass_cutoff: int=1000,
#resample_freq: int=4000,
#delta_pitch: float=0.005,
#nccf_ballast: int=7000,
#lowpass_filter_width: int=1,
#upsample_filter_width: int=5,
#max_frames_latency: int=0,
#frames_per_chunk: int=0,
#simulate_first_pass_online: bool=False,
#recompute_frame: int=500,
#nccf_ballast_online: bool=False,
#snip_edges: bool=True):
#pitch_opts = paddleaudio._paddleaudio.PitchExtractionOptions()
#pitch_opts.samp_freq = samp_freq
#pitch_opts.frame_shift_ms = frame_shift_ms
#pitch_opts.frame_length_ms = frame_length_ms
#pitch_opts.preemph_coeff = preemph_coeff
#pitch_opts.min_f0 = min_f0
#pitch_opts.max_f0 = max_f0
#pitch_opts.soft_min_f0 = soft_min_f0
#pitch_opts.penalty_factor = penalty_factor
#pitch_opts.lowpass_cutoff = lowpass_cutoff
#pitch_opts.resample_freq = resample_freq
#pitch_opts.delta_pitch = delta_pitch
#pitch_opts.nccf_ballast = nccf_ballast
#pitch_opts.lowpass_filter_width = lowpass_filter_width
#pitch_opts.upsample_filter_width = upsample_filter_width
#pitch_opts.max_frames_latency = max_frames_latency
#pitch_opts.frames_per_chunk = frames_per_chunk
#pitch_opts.simulate_first_pass_online = simulate_first_pass_online
#pitch_opts.recompute_frame = recompute_frame
#pitch_opts.nccf_ballast_online = nccf_ballast_online
#pitch_opts.snip_edges = snip_edges
#pitch = paddleaudio._paddleaudio.ComputeKaldiPitch(pitch_opts, wav)
#return pitch
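For reference, a hedged sketch of calling the remaining `fbank` API from Python; `num_bins` and `snip_edges` come from the signature above, while the import path and the `samp_freq` keyword are assumptions, since the full signature is not shown in this diff:
```python
import numpy as np
from paddleaudio.kaldi import fbank  # assumed import path, cf. `from .kaldi import fbank`

sr = 16000
wav = np.random.uniform(-1.0, 1.0, sr).astype(np.float32)  # one second of noise
feat = fbank(
    wav,
    samp_freq=sr,    # assumed keyword name
    num_bins=80,     # 80-dimensional log-mel filterbank
    snip_edges=True,
)
print(feat.shape)    # (num_frames, num_bins)
```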

@ -52,7 +52,7 @@ if(BUILD_KALDI)
list(
APPEND
LIBPADDLEAUDIO_LINK_LIBRARIES
libkaldi
kaldi-native-fbank-core
)
list(
APPEND
@ -92,14 +92,6 @@ define_library(
"${LIBPADDLEAUDIO_COMPILE_DEFINITIONS}"
)
if (APPLE)
add_custom_command(TARGET libpaddleaudio POST_BUILD COMMAND install_name_tool -change "${GFORTRAN_LIBRARIES_DIR}/libgcc_s.1.1.dylib" "@loader_path/libgcc_s.1.1.dylib" libpaddleaudio.so)
endif(APPLE)
if (UNIX AND NOT APPLE)
set_target_properties(libpaddleaudio PROPERTIES INSTALL_RPATH "$ORIGIN")
endif()
if (APPLE)
set(AUDIO_LIBRARY libpaddleaudio CACHE INTERNAL "")
else()
@ -207,11 +199,3 @@ define_extension(
# )
# endif()
endif()
if (APPLE)
add_custom_command(TARGET _paddleaudio POST_BUILD COMMAND install_name_tool -change "${GFORTRAN_LIBRARIES_DIR}/libgcc_s.1.1.dylib" "@loader_path/lib/libgcc_s.1.1.dylib" _paddleaudio.so)
endif(APPLE)
if (UNIX AND NOT APPLE)
set_target_properties(_paddleaudio PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
endif()

@ -16,7 +16,7 @@
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "feat/feature-window.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
namespace paddleaudio {
namespace kaldi {
@ -28,18 +28,18 @@ class StreamingFeatureTpl {
public:
typedef typename F::Options Options;
StreamingFeatureTpl(const Options& opts);
bool ComputeFeature(const ::kaldi::VectorBase<::kaldi::BaseFloat>& wav,
::kaldi::Vector<::kaldi::BaseFloat>* feats);
void Reset() { remained_wav_.Resize(0); }
bool ComputeFeature(const std::vector<float>& wav,
std::vector<float>* feats);
void Reset() { remained_wav_.resize(0); }
int Dim() { return computer_.Dim(); }
private:
bool Compute(const ::kaldi::Vector<::kaldi::BaseFloat>& waves,
::kaldi::Vector<::kaldi::BaseFloat>* feats);
bool Compute(const std::vector<float>& waves,
std::vector<float>* feats);
Options opts_;
::kaldi::FeatureWindowFunction window_function_;
::kaldi::Vector<::kaldi::BaseFloat> remained_wav_;
knf::FeatureWindowFunction window_function_;
std::vector<float> remained_wav_;
F computer_;
};

@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
namespace paddleaudio {
namespace kaldi {
@ -25,24 +24,29 @@ StreamingFeatureTpl<F>::StreamingFeatureTpl(const Options& opts)
template <class F>
bool StreamingFeatureTpl<F>::ComputeFeature(
const ::kaldi::VectorBase<::kaldi::BaseFloat>& wav,
::kaldi::Vector<::kaldi::BaseFloat>* feats) {
const std::vector<float>& wav,
std::vector<float>* feats) {
// append remained waves
::kaldi::int32 wav_len = wav.Dim();
int wav_len = wav.size();
if (wav_len == 0) return false;
::kaldi::int32 left_len = remained_wav_.Dim();
::kaldi::Vector<::kaldi::BaseFloat> waves(left_len + wav_len);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, wav_len).CopyFromVec(wav);
int left_len = remained_wav_.size();
std::vector<float> waves(left_len + wav_len);
std::memcpy(waves.data(),
remained_wav_.data(),
left_len * sizeof(float));
std::memcpy(waves.data() + left_len,
wav.data(),
wav_len * sizeof(float));
// cache remained waves
::kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
::kaldi::int32 num_frames = ::kaldi::NumFrames(waves.Dim(), frame_opts);
::kaldi::int32 frame_shift = frame_opts.WindowShift();
::kaldi::int32 left_samples = waves.Dim() - frame_shift * num_frames;
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
int num_frames = knf::NumFrames(waves.size(), frame_opts);
int frame_shift = frame_opts.WindowShift();
int left_samples = waves.size() - frame_shift * num_frames;
remained_wav_.resize(left_samples);
std::memcpy(remained_wav_.data(),
waves.data() + frame_shift * num_frames,
left_samples * sizeof(float));
// compute speech feature
Compute(waves, feats);
@ -51,40 +55,39 @@ bool StreamingFeatureTpl<F>::ComputeFeature(
// Compute feat
template <class F>
bool StreamingFeatureTpl<F>::Compute(
const ::kaldi::Vector<::kaldi::BaseFloat>& waves,
::kaldi::Vector<::kaldi::BaseFloat>* feats) {
::kaldi::BaseFloat vtln_warp = 1.0;
const ::kaldi::FrameExtractionOptions& frame_opts =
computer_.GetFrameOptions();
::kaldi::int32 num_samples = waves.Dim();
::kaldi::int32 frame_length = frame_opts.WindowSize();
::kaldi::int32 sample_rate = frame_opts.samp_freq;
bool StreamingFeatureTpl<F>::Compute(const std::vector<float>& waves,
std::vector<float>* feats) {
const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions();
int num_samples = waves.size();
int frame_length = frame_opts.WindowSize();
int sample_rate = frame_opts.samp_freq;
if (num_samples < frame_length) {
return false;
return true;
}
::kaldi::int32 num_frames = ::kaldi::NumFrames(num_samples, frame_opts);
feats->Resize(num_frames * Dim());
int num_frames = knf::NumFrames(num_samples, frame_opts);
feats->resize(num_frames * Dim());
::kaldi::Vector<::kaldi::BaseFloat> window;
std::vector<float> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy();
for (::kaldi::int32 frame = 0; frame < num_frames; frame++) {
::kaldi::BaseFloat raw_log_energy = 0.0;
::kaldi::ExtractWindow(0,
waves,
frame,
frame_opts,
window_function_,
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
for (int frame = 0; frame < num_frames; frame++) {
std::fill(window.begin(), window.end(), 0);
float raw_log_energy = 0.0;
float vtln_warp = 1.0;
knf::ExtractWindow(0,
waves,
frame,
frame_opts,
window_function_,
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
::kaldi::Vector<::kaldi::BaseFloat> this_feature(computer_.Dim(),
::kaldi::kUndefined);
computer_.Compute(raw_log_energy, vtln_warp, &window, &this_feature);
::kaldi::SubVector<::kaldi::BaseFloat> output_row(
feats->Data() + frame * Dim(), Dim());
output_row.CopyFromVec(this_feature);
std::vector<float> this_feature(computer_.Dim());
computer_.Compute(
raw_log_energy, vtln_warp, &window, this_feature.data());
std::memcpy(feats->data() + frame * Dim(),
this_feature.data(),
sizeof(float) * Dim());
}
return true;
}
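The remainder-caching logic in `ComputeFeature` above can be sketched in Python/numpy for clarity (a hypothetical helper, not part of paddleaudio; `window`/`shift` are in samples and mirror `WindowSize()`/`WindowShift()` with `snip_edges == true`):
```python
import numpy as np

def stream_frames(chunks, window=400, shift=160):
    """Yield full frames across chunk boundaries, caching leftover samples."""
    remained = np.zeros(0, dtype=np.float32)  # plays the role of remained_wav_
    for chunk in chunks:
        waves = np.concatenate([remained, chunk])
        if len(waves) < window:           # not enough samples for a single frame yet
            remained = waves
            continue
        num_frames = 1 + (len(waves) - window) // shift   # same as knf::NumFrames
        for f in range(num_frames):
            yield waves[f * shift : f * shift + window]
        remained = waves[shift * num_frames:]  # carry the tail into the next call
```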

@ -13,16 +13,16 @@
// limitations under the License.
#include "paddleaudio/src/pybind/kaldi/kaldi_feature.h"
#include "feat/pitch-functions.h"
//#include "feat/pitch-functions.h"
namespace paddleaudio {
namespace kaldi {
bool InitFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
knf::FrameExtractionOptions frame_opts,
knf::MelBanksOptions mel_opts,
FbankOptions fbank_opts) {
::kaldi::FbankOptions opts;
knf::FbankOptions opts;
opts.frame_opts = frame_opts;
opts.mel_opts = mel_opts;
opts.use_energy = fbank_opts.use_energy;
@ -41,8 +41,8 @@ py::array_t<float> ComputeFbankStreaming(const py::array_t<float>& wav) {
}
py::array_t<float> ComputeFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
knf::FrameExtractionOptions frame_opts,
knf::MelBanksOptions mel_opts,
FbankOptions fbank_opts,
const py::array_t<float>& wav) {
InitFbank(frame_opts, mel_opts, fbank_opts);
@ -55,21 +55,21 @@ void ResetFbank() {
paddleaudio::kaldi::KaldiFeatureWrapper::GetInstance()->ResetFbank();
}
py::array_t<float> ComputeKaldiPitch(
const ::kaldi::PitchExtractionOptions& opts,
const py::array_t<float>& wav) {
py::buffer_info info = wav.request();
::kaldi::SubVector<::kaldi::BaseFloat> input_wav((float*)info.ptr, info.size);
//py::array_t<float> ComputeKaldiPitch(
//const ::kaldi::PitchExtractionOptions& opts,
//const py::array_t<float>& wav) {
//py::buffer_info info = wav.request();
//::kaldi::SubVector<::kaldi::BaseFloat> input_wav((float*)info.ptr, info.size);
::kaldi::Matrix<::kaldi::BaseFloat> features;
::kaldi::ComputeKaldiPitch(opts, input_wav, &features);
auto result = py::array_t<float>({features.NumRows(), features.NumCols()});
for (int row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
std::memcpy(result.mutable_data(row_idx), features.Row(row_idx).Data(),
sizeof(float)*features.NumCols());
}
return result;
}
//::kaldi::Matrix<::kaldi::BaseFloat> features;
//::kaldi::ComputeKaldiPitch(opts, input_wav, &features);
//auto result = py::array_t<float>({features.NumRows(), features.NumCols()});
//for (int row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
//std::memcpy(result.mutable_data(row_idx), features.Row(row_idx).Data(),
//sizeof(float)*features.NumCols());
//}
//return result;
//}
} // namespace kaldi
} // namespace paddleaudio
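A hypothetical end-to-end streaming sketch over the wrapper functions in this file; `InitFbank`, `ComputeFbankStreaming`, and `ResetFbank` all appear above, but whether each is exposed to Python next to `ComputeFbank` is an assumption:
```python
import numpy as np
from paddleaudio import _paddleaudio

frame_opts = _paddleaudio.FrameExtractionOptions()
mel_opts = _paddleaudio.MelBanksOptions()
fbank_opts = _paddleaudio.FbankOptions()

_paddleaudio.InitFbank(frame_opts, mel_opts, fbank_opts)   # assumed binding
wav_chunks = [np.zeros(1600, dtype=np.float32)] * 5        # dummy 100 ms chunks
for chunk in wav_chunks:
    feats = _paddleaudio.ComputeFbankStreaming(chunk)      # frames ready so far
_paddleaudio.ResetFbank()                                  # drop the cached remainder
```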

@ -19,7 +19,7 @@
#include <string>
#include "paddleaudio/src/pybind/kaldi/kaldi_feature_wrapper.h"
#include "feat/pitch-functions.h"
//#include "feat/pitch-functions.h"
namespace py = pybind11;
@ -42,13 +42,13 @@ struct FbankOptions{
};
bool InitFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
knf::FrameExtractionOptions frame_opts,
knf::MelBanksOptions mel_opts,
FbankOptions fbank_opts);
py::array_t<float> ComputeFbank(
::kaldi::FrameExtractionOptions frame_opts,
::kaldi::MelBanksOptions mel_opts,
knf::FrameExtractionOptions frame_opts,
knf::MelBanksOptions mel_opts,
FbankOptions fbank_opts,
const py::array_t<float>& wav);
@ -56,9 +56,9 @@ py::array_t<float> ComputeFbankStreaming(const py::array_t<float>& wav);
void ResetFbank();
py::array_t<float> ComputeKaldiPitch(
const ::kaldi::PitchExtractionOptions& opts,
const py::array_t<float>& wav);
//py::array_t<float> ComputeKaldiPitch(
//const ::kaldi::PitchExtractionOptions& opts,
//const py::array_t<float>& wav);
} // namespace kaldi
} // namespace paddleaudio

@ -22,7 +22,7 @@ KaldiFeatureWrapper* KaldiFeatureWrapper::GetInstance() {
return &instance;
}
bool KaldiFeatureWrapper::InitFbank(::kaldi::FbankOptions opts) {
bool KaldiFeatureWrapper::InitFbank(knf::FbankOptions opts) {
fbank_.reset(new Fbank(opts));
return true;
}
@ -30,21 +30,18 @@ bool KaldiFeatureWrapper::InitFbank(::kaldi::FbankOptions opts) {
py::array_t<float> KaldiFeatureWrapper::ComputeFbank(
const py::array_t<float> wav) {
py::buffer_info info = wav.request();
::kaldi::SubVector<::kaldi::BaseFloat> input_wav((float*)info.ptr, info.size);
std::vector<float> input_wav((float*)info.ptr, (float*)info.ptr + info.size);
::kaldi::Vector<::kaldi::BaseFloat> feats;
std::vector<float> feats;
bool flag = fbank_->ComputeFeature(input_wav, &feats);
if (flag == false || feats.Dim() == 0) return py::array_t<float>();
auto result = py::array_t<float>(feats.Dim());
if (flag == false || feats.size() == 0) return py::array_t<float>();
auto result = py::array_t<float>(feats.size());
py::buffer_info xs = result.request();
std::cout << std::endl;
float* res_ptr = (float*)xs.ptr;
for (int idx = 0; idx < feats.Dim(); ++idx) {
*res_ptr = feats(idx);
res_ptr++;
}
return result.reshape({feats.Dim() / Dim(), Dim()});
std::memcpy(res_ptr, feats.data(), sizeof(float)*feats.size());
std::vector<int> shape{static_cast<int>(feats.size() / Dim()),
static_cast<int>(Dim())};
return result.reshape(shape);
}
} // namespace kaldi

@ -14,20 +14,18 @@
#pragma once
#include "base/kaldi-common.h"
#include "feat/feature-fbank.h"
#include "paddleaudio/third_party/kaldi-native-fbank/csrc/feature-fbank.h"
#include "paddleaudio/src/pybind/kaldi/feature_common.h"
namespace paddleaudio {
namespace kaldi {
typedef StreamingFeatureTpl<::kaldi::FbankComputer> Fbank;
typedef StreamingFeatureTpl<knf::FbankComputer> Fbank;
class KaldiFeatureWrapper {
public:
static KaldiFeatureWrapper* GetInstance();
bool InitFbank(::kaldi::FbankOptions opts);
bool InitFbank(knf::FbankOptions opts);
py::array_t<float> ComputeFbank(const py::array_t<float> wav);
int Dim() { return fbank_->Dim(); }
void ResetFbank() { fbank_->Reset(); }

@ -2,7 +2,7 @@
#ifdef INCLUDE_KALDI
#include "paddleaudio/src/pybind/kaldi/kaldi_feature.h"
#include "paddleaudio/third_party/kaldi/feat/feature-fbank.h"
#include "paddleaudio/third_party/kaldi-native-fbank/csrc/feature-fbank.h"
#endif
#ifdef INCLUDE_SOX
@ -89,53 +89,51 @@ PYBIND11_MODULE(_paddleaudio, m) {
#ifdef INCLUDE_KALDI
m.def("ComputeFbank", &paddleaudio::kaldi::ComputeFbank, "compute fbank");
py::class_<kaldi::PitchExtractionOptions>(m, "PitchExtractionOptions")
.def(py::init<>())
.def_readwrite("samp_freq", &kaldi::PitchExtractionOptions::samp_freq)
.def_readwrite("frame_shift_ms", &kaldi::PitchExtractionOptions::frame_shift_ms)
.def_readwrite("frame_length_ms", &kaldi::PitchExtractionOptions::frame_length_ms)
.def_readwrite("preemph_coeff", &kaldi::PitchExtractionOptions::preemph_coeff)
.def_readwrite("min_f0", &kaldi::PitchExtractionOptions::min_f0)
.def_readwrite("max_f0", &kaldi::PitchExtractionOptions::max_f0)
.def_readwrite("soft_min_f0", &kaldi::PitchExtractionOptions::soft_min_f0)
.def_readwrite("penalty_factor", &kaldi::PitchExtractionOptions::penalty_factor)
.def_readwrite("lowpass_cutoff", &kaldi::PitchExtractionOptions::lowpass_cutoff)
.def_readwrite("resample_freq", &kaldi::PitchExtractionOptions::resample_freq)
.def_readwrite("delta_pitch", &kaldi::PitchExtractionOptions::delta_pitch)
.def_readwrite("nccf_ballast", &kaldi::PitchExtractionOptions::nccf_ballast)
.def_readwrite("lowpass_filter_width", &kaldi::PitchExtractionOptions::lowpass_filter_width)
.def_readwrite("upsample_filter_width", &kaldi::PitchExtractionOptions::upsample_filter_width)
.def_readwrite("max_frames_latency", &kaldi::PitchExtractionOptions::max_frames_latency)
.def_readwrite("frames_per_chunk", &kaldi::PitchExtractionOptions::frames_per_chunk)
.def_readwrite("simulate_first_pass_online", &kaldi::PitchExtractionOptions::simulate_first_pass_online)
.def_readwrite("recompute_frame", &kaldi::PitchExtractionOptions::recompute_frame)
.def_readwrite("nccf_ballast_online", &kaldi::PitchExtractionOptions::nccf_ballast_online)
.def_readwrite("snip_edges", &kaldi::PitchExtractionOptions::snip_edges);
m.def("ComputeKaldiPitch", &paddleaudio::kaldi::ComputeKaldiPitch, "compute kaldi pitch");
py::class_<kaldi::FrameExtractionOptions>(m, "FrameExtractionOptions")
//py::class_<kaldi::PitchExtractionOptions>(m, "PitchExtractionOptions")
//.def(py::init<>())
//.def_readwrite("samp_freq", &kaldi::PitchExtractionOptions::samp_freq)
//.def_readwrite("frame_shift_ms", &kaldi::PitchExtractionOptions::frame_shift_ms)
//.def_readwrite("frame_length_ms", &kaldi::PitchExtractionOptions::frame_length_ms)
//.def_readwrite("preemph_coeff", &kaldi::PitchExtractionOptions::preemph_coeff)
//.def_readwrite("min_f0", &kaldi::PitchExtractionOptions::min_f0)
//.def_readwrite("max_f0", &kaldi::PitchExtractionOptions::max_f0)
//.def_readwrite("soft_min_f0", &kaldi::PitchExtractionOptions::soft_min_f0)
//.def_readwrite("penalty_factor", &kaldi::PitchExtractionOptions::penalty_factor)
//.def_readwrite("lowpass_cutoff", &kaldi::PitchExtractionOptions::lowpass_cutoff)
//.def_readwrite("resample_freq", &kaldi::PitchExtractionOptions::resample_freq)
//.def_readwrite("delta_pitch", &kaldi::PitchExtractionOptions::delta_pitch)
//.def_readwrite("nccf_ballast", &kaldi::PitchExtractionOptions::nccf_ballast)
//.def_readwrite("lowpass_filter_width", &kaldi::PitchExtractionOptions::lowpass_filter_width)
//.def_readwrite("upsample_filter_width", &kaldi::PitchExtractionOptions::upsample_filter_width)
//.def_readwrite("max_frames_latency", &kaldi::PitchExtractionOptions::max_frames_latency)
//.def_readwrite("frames_per_chunk", &kaldi::PitchExtractionOptions::frames_per_chunk)
//.def_readwrite("simulate_first_pass_online", &kaldi::PitchExtractionOptions::simulate_first_pass_online)
//.def_readwrite("recompute_frame", &kaldi::PitchExtractionOptions::recompute_frame)
//.def_readwrite("nccf_ballast_online", &kaldi::PitchExtractionOptions::nccf_ballast_online)
//.def_readwrite("snip_edges", &kaldi::PitchExtractionOptions::snip_edges);
//m.def("ComputeKaldiPitch", &paddleaudio::kaldi::ComputeKaldiPitch, "compute kaldi pitch");
py::class_<knf::FrameExtractionOptions>(m, "FrameExtractionOptions")
.def(py::init<>())
.def_readwrite("samp_freq", &kaldi::FrameExtractionOptions::samp_freq)
.def_readwrite("frame_shift_ms", &kaldi::FrameExtractionOptions::frame_shift_ms)
.def_readwrite("frame_length_ms", &kaldi::FrameExtractionOptions::frame_length_ms)
.def_readwrite("dither", &kaldi::FrameExtractionOptions::dither)
.def_readwrite("preemph_coeff", &kaldi::FrameExtractionOptions::preemph_coeff)
.def_readwrite("remove_dc_offset", &kaldi::FrameExtractionOptions::remove_dc_offset)
.def_readwrite("window_type", &kaldi::FrameExtractionOptions::window_type)
.def_readwrite("round_to_power_of_two", &kaldi::FrameExtractionOptions::round_to_power_of_two)
.def_readwrite("blackman_coeff", &kaldi::FrameExtractionOptions::blackman_coeff)
.def_readwrite("snip_edges", &kaldi::FrameExtractionOptions::snip_edges)
.def_readwrite("allow_downsample", &kaldi::FrameExtractionOptions::allow_downsample)
.def_readwrite("allow_upsample", &kaldi::FrameExtractionOptions::allow_upsample)
.def_readwrite("max_feature_vectors", &kaldi::FrameExtractionOptions::max_feature_vectors);
py::class_<kaldi::MelBanksOptions>(m, "MelBanksOptions")
.def_readwrite("samp_freq", &knf::FrameExtractionOptions::samp_freq)
.def_readwrite("frame_shift_ms", &knf::FrameExtractionOptions::frame_shift_ms)
.def_readwrite("frame_length_ms", &knf::FrameExtractionOptions::frame_length_ms)
.def_readwrite("dither", &knf::FrameExtractionOptions::dither)
.def_readwrite("preemph_coeff", &knf::FrameExtractionOptions::preemph_coeff)
.def_readwrite("remove_dc_offset", &knf::FrameExtractionOptions::remove_dc_offset)
.def_readwrite("window_type", &knf::FrameExtractionOptions::window_type)
.def_readwrite("round_to_power_of_two", &knf::FrameExtractionOptions::round_to_power_of_two)
.def_readwrite("blackman_coeff", &knf::FrameExtractionOptions::blackman_coeff)
.def_readwrite("snip_edges", &knf::FrameExtractionOptions::snip_edges)
.def_readwrite("max_feature_vectors", &knf::FrameExtractionOptions::max_feature_vectors);
py::class_<knf::MelBanksOptions>(m, "MelBanksOptions")
.def(py::init<>())
.def_readwrite("num_bins", &kaldi::MelBanksOptions::num_bins)
.def_readwrite("low_freq", &kaldi::MelBanksOptions::low_freq)
.def_readwrite("high_freq", &kaldi::MelBanksOptions::high_freq)
.def_readwrite("vtln_low", &kaldi::MelBanksOptions::vtln_low)
.def_readwrite("vtln_high", &kaldi::MelBanksOptions::vtln_high)
.def_readwrite("debug_mel", &kaldi::MelBanksOptions::debug_mel)
.def_readwrite("htk_mode", &kaldi::MelBanksOptions::htk_mode);
.def_readwrite("num_bins", &knf::MelBanksOptions::num_bins)
.def_readwrite("low_freq", &knf::MelBanksOptions::low_freq)
.def_readwrite("high_freq", &knf::MelBanksOptions::high_freq)
.def_readwrite("vtln_low", &knf::MelBanksOptions::vtln_low)
.def_readwrite("vtln_high", &knf::MelBanksOptions::vtln_high)
.def_readwrite("debug_mel", &knf::MelBanksOptions::debug_mel)
.def_readwrite("htk_mode", &knf::MelBanksOptions::htk_mode);
py::class_<paddleaudio::kaldi::FbankOptions>(m, "FbankOptions")
.def(py::init<>())
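A hypothetical sketch of driving these bindings directly from Python; the class and field names come from the `py::class_` definitions above, and passing the three option objects straight to `ComputeFbank` follows the C++ signature earlier in this PR:
```python
import numpy as np
from paddleaudio import _paddleaudio

frame_opts = _paddleaudio.FrameExtractionOptions()
frame_opts.samp_freq = 16000
frame_opts.frame_length_ms = 25.0
frame_opts.frame_shift_ms = 10.0

mel_opts = _paddleaudio.MelBanksOptions()
mel_opts.num_bins = 80

fbank_opts = _paddleaudio.FbankOptions()  # default-constructed options

wav = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)  # 1 s at 16 kHz
feat = _paddleaudio.ComputeFbank(frame_opts, mel_opts, fbank_opts, wav)
```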

@ -11,5 +11,6 @@ endif()
# kaldi
################################################################################
if (BUILD_KALDI)
add_subdirectory(kaldi)
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_subdirectory(kaldi-native-fbank/csrc)
endif()

@ -0,0 +1,22 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../)
add_library(kaldi-native-fbank-core
feature-fbank.cc
feature-functions.cc
feature-window.cc
fftsg.c
log.cc
mel-computations.cc
rfft.cc
)
# We are using std::call_once() in log.h, which requires us to link with -pthread
if(NOT WIN32)
target_link_libraries(kaldi-native-fbank-core -pthread)
endif()
if(KNF_HAVE_EXECINFO_H)
target_compile_definitions(kaldi-native-fbank-core PRIVATE KNF_HAVE_EXECINFO_H=1)
endif()
if(KNF_HAVE_CXXABI_H)
target_compile_definitions(kaldi-native-fbank-core PRIVATE KNF_HAVE_CXXABI_H=1)
endif()

@ -0,0 +1,117 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-fbank.cc
//
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include <cmath>
#include "kaldi-native-fbank/csrc/feature-functions.h"
namespace knf {
static void Sqrt(float *in_out, int32_t n) {
for (int32_t i = 0; i != n; ++i) {
in_out[i] = std::sqrt(in_out[i]);
}
}
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) {
os << opts.ToString();
return os;
}
FbankComputer::FbankComputer(const FbankOptions &opts)
: opts_(opts), rfft_(opts.frame_opts.PaddedWindowSize()) {
if (opts.energy_floor > 0.0f) {
log_energy_floor_ = logf(opts.energy_floor);
}
// We'll definitely need the filterbanks info for VTLN warping factor 1.0.
// [note: this call caches it.]
GetMelBanks(1.0f);
}
FbankComputer::~FbankComputer() {
for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter)
delete iter->second;
}
const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) {
MelBanks *this_mel_banks = nullptr;
// std::map<float, MelBanks *>::iterator iter = mel_banks_.find(vtln_warp);
auto iter = mel_banks_.find(vtln_warp);
if (iter == mel_banks_.end()) {
this_mel_banks = new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp);
mel_banks_[vtln_warp] = this_mel_banks;
} else {
this_mel_banks = iter->second;
}
return this_mel_banks;
}
void FbankComputer::Compute(float signal_raw_log_energy, float vtln_warp,
std::vector<float> *signal_frame, float *feature) {
const MelBanks &mel_banks = *(GetMelBanks(vtln_warp));
KNF_CHECK_EQ(signal_frame->size(), opts_.frame_opts.PaddedWindowSize());
// Compute energy after window function (not the raw one).
if (opts_.use_energy && !opts_.raw_energy) {
signal_raw_log_energy = std::log(
std::max<float>(InnerProduct(signal_frame->data(), signal_frame->data(),
signal_frame->size()),
std::numeric_limits<float>::epsilon()));
}
rfft_.Compute(signal_frame->data()); // signal_frame is modified in-place
ComputePowerSpectrum(signal_frame);
// Use magnitude instead of power if requested.
if (!opts_.use_power) {
Sqrt(signal_frame->data(), signal_frame->size() / 2 + 1);
}
int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0);
// Its length is opts_.mel_opts.num_bins
float *mel_energies = feature + mel_offset;
// Sum with mel filter banks over the power spectrum
mel_banks.Compute(signal_frame->data(), mel_energies);
if (opts_.use_log_fbank) {
// Avoid log of zero (which should be prevented anyway by dithering).
for (int32_t i = 0; i != opts_.mel_opts.num_bins; ++i) {
auto t = std::max(mel_energies[i], std::numeric_limits<float>::epsilon());
mel_energies[i] = std::log(t);
}
}
// Copy energy as first value (or the last, if htk_compat == true).
if (opts_.use_energy) {
if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) {
signal_raw_log_energy = log_energy_floor_;
}
int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0;
feature[energy_index] = signal_raw_log_energy;
}
}
} // namespace knf
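// A minimal usage sketch (illustrative, not part of this commit) showing how
// the pieces above fit together; it assumes a mono 16 kHz waveform in
// `samples` and that the kaldi-native-fbank headers are on the include path.
#include <cstdint>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
std::vector<float> ComputeFbank(const std::vector<float> &samples) {
  knf::FbankOptions opts;
  opts.frame_opts.dither = 0;  // ProcessWindow() requires dither == 0
  opts.mel_opts.num_bins = 80;
  knf::FbankComputer fbank(opts);
  knf::FeatureWindowFunction window_fn(opts.frame_opts);
  int32_t num_frames = knf::NumFrames(samples.size(), opts.frame_opts);
  std::vector<float> features(num_frames * fbank.Dim());
  std::vector<float> window;  // reused as a workspace across frames
  for (int32_t f = 0; f != num_frames; ++f) {
    float raw_log_energy = 0;
    knf::ExtractWindow(0, samples, f, opts.frame_opts, window_fn, &window,
                       fbank.NeedRawLogEnergy() ? &raw_log_energy : nullptr);
    fbank.Compute(raw_log_energy, /*vtln_warp=*/1.0f, &window,
                  features.data() + f * fbank.Dim());
  }
  return features;
}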

@ -0,0 +1,132 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-fbank.h
#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#include <cstdint>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/rfft.h"
namespace knf {
struct FbankOptions {
FrameExtractionOptions frame_opts;
MelBanksOptions mel_opts;
// append an extra dimension with energy to the filter banks
bool use_energy = false;
float energy_floor = 0.0f; // active iff use_energy==true
// If true, compute log_energy before preemphasis and windowing
// If false, compute log_energy after preemphasis and windowing
bool raw_energy = true; // active iff use_energy==true
// If true, put energy last (if using energy)
// If false, put energy first
bool htk_compat = false; // active iff use_energy==true
// if true (default), produce log-filterbank, else linear
bool use_log_fbank = true;
// if true (default), use power in filterbank
// analysis, else magnitude.
bool use_power = true;
FbankOptions() { mel_opts.num_bins = 23; }
std::string ToString() const {
std::ostringstream os;
os << "frame_opts: \n";
os << frame_opts << "\n";
os << "\n";
os << "mel_opts: \n";
os << mel_opts << "\n";
os << "use_energy: " << use_energy << "\n";
os << "energy_floor: " << energy_floor << "\n";
os << "raw_energy: " << raw_energy << "\n";
os << "htk_compat: " << htk_compat << "\n";
os << "use_log_fbank: " << use_log_fbank << "\n";
os << "use_power: " << use_power << "\n";
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
class FbankComputer {
public:
using Options = FbankOptions;
explicit FbankComputer(const FbankOptions &opts);
~FbankComputer();
int32_t Dim() const {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
}
// if true, compute log_energy_pre_window, i.e., before windowing
// but after dithering and dc removal
bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}
const FbankOptions &GetOptions() const { return opts_; }
/**
Function that computes one frame of features from
one frame of signal.
@param [in] signal_raw_log_energy The log-energy of the frame of the signal
prior to windowing and pre-emphasis, or
log(numeric_limits<float>::min()), whichever is greater. Must be
ignored by this function if this class returns false from
this->NeedRawLogEnergy().
@param [in] vtln_warp The VTLN warping factor that the user wants
to be applied when computing features for this utterance. Will
normally be 1.0, meaning no warping is to be done. The value will
be ignored for feature types that don't support VTLN, such as
spectrogram features.
@param [in] signal_frame One frame of the signal,
as extracted using the function ExtractWindow() using the options
returned by this->GetFrameOptions(). The function will use the
vector as a workspace, which is why it's a non-const pointer.
@param [out] feature Pointer to a vector of size this->Dim(), to which
the computed feature will be written. It should be pre-allocated.
*/
void Compute(float signal_raw_log_energy, float vtln_warp,
std::vector<float> *signal_frame, float *feature);
private:
const MelBanks *GetMelBanks(float vtln_warp);
FbankOptions opts_;
float log_energy_floor_;
std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
Rfft rfft_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
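For reference (illustrative, not part of this commit), the per-frame feature layout implied by the option semantics above:
// use_energy == false                       -> [m0, m1, ..., m{num_bins-1}]
// use_energy == true, htk_compat == false   -> [energy, m0, ..., m{num_bins-1}]
// use_energy == true, htk_compat == true    -> [m0, ..., m{num_bins-1}, energy]
// Dim() == num_bins + (use_energy ? 1 : 0) in every case.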

@ -0,0 +1,49 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-functions.cc
#include "kaldi-native-fbank/csrc/feature-functions.h"
#include <cstdint>
#include <vector>
namespace knf {
void ComputePowerSpectrum(std::vector<float> *complex_fft) {
int32_t dim = complex_fft->size();
// now we have in complex_fft, first half of complex spectrum
// it's stored as [real0, realN/2, real1, im1, real2, im2, ...]
float *p = complex_fft->data();
int32_t half_dim = dim / 2;
float first_energy = p[0] * p[0];
float last_energy = p[1] * p[1]; // handle this special case
for (int32_t i = 1; i < half_dim; ++i) {
float real = p[i * 2];
float im = p[i * 2 + 1];
p[i] = real * real + im * im;
}
p[0] = first_energy;
p[half_dim] = last_energy; // Will actually never be used, and anyway
// if the signal has been bandlimited sensibly this should be zero.
}
} // namespace knf
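A worked example (illustrative) for n = 4: Rfft::Compute() leaves the packed half-spectrum [R0, R2, R1, I1] in the buffer, and ComputePowerSpectrum() rewrites it in place as follows:
// input : [R0, R2, R1, I1]   (packed FFT of a 4-point real signal)
// output: [R0*R0, R1*R1 + I1*I1, R2*R2, I1]
//          bin 0   bin 1          bin 2 (Nyquist); the last slot is unused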

@ -0,0 +1,38 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/feature-functions.h
#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H
#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H
#include <vector>
namespace knf {
// ComputePowerSpectrum converts a complex FFT (as produced by the FFT
// functions in csrc/rfft.h), and converts it into
// a power spectrum. If the complex FFT is a vector of size n (representing
// half of the complex FFT of a real signal of size n, as described there),
// this function computes in the first (n/2) + 1 elements of it, the
// energies of the fft bins from zero to the Nyquist frequency. Contents of the
// remaining (n/2) - 1 elements are undefined at output.
void ComputePowerSpectrum(std::vector<float> *complex_fft);
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H

@ -0,0 +1,236 @@
// kaldi-native-fbank/csrc/feature-window.cc
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-window.cc
#include "kaldi-native-fbank/csrc/feature-window.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace knf {
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
os << opts.ToString();
return os;
}
FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts)
: window_(opts.WindowSize()) {
int32_t frame_length = opts.WindowSize();
KNF_CHECK_GT(frame_length, 0);
float *window_data = window_.data();
double a = M_2PI / (frame_length - 1);
for (int32_t i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
window_data[i] = 0.5 - 0.5 * cos(a * i_fl);
} else if (opts.window_type == "sine") {
// when you are checking with Wikipedia, please
// note that 0.5 * a = M_PI/(frame_length-1)
window_data[i] = sin(0.5 * a * i_fl);
} else if (opts.window_type == "hamming") {
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
} else if (opts.window_type ==
"povey") { // like hamming but goes to zero at edges.
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
} else if (opts.window_type == "rectangular") {
window_data[i] = 1.0;
} else if (opts.window_type == "blackman") {
window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) +
(0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
} else {
KNF_LOG(FATAL) << "Invalid window type " << opts.window_type;
}
}
}
void FeatureWindowFunction::Apply(float *wave) const {
int32_t window_size = window_.size();
const float *p = window_.data();
for (int32_t k = 0; k != window_size; ++k) {
wave[k] *= p[k];
}
}
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) {
int64_t frame_shift = opts.WindowShift();
if (opts.snip_edges) {
return frame * frame_shift;
} else {
int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2,
beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2;
return beginning_of_frame;
}
}
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush /*= true*/) {
int64_t frame_shift = opts.WindowShift();
int64_t frame_length = opts.WindowSize();
if (opts.snip_edges) {
// with --snip-edges=true (the default), we use a HTK-like approach to
// determining the number of frames-- all frames have to fit completely into
// the waveform, and the first frame begins at sample zero.
if (num_samples < frame_length)
return 0;
else
return (1 + ((num_samples - frame_length) / frame_shift));
// You can understand the expression above as follows: 'num_samples -
// frame_length' is how much room we have to shift the frame within the
// waveform; 'frame_shift' is how much we shift it each time; and the ratio
// is how many times we can shift it (integer arithmetic rounds down).
} else {
// if --snip-edges=false, the number of frames is determined by rounding the
// (file-length / frame-shift) to the nearest integer. The point of this
// formula is to make the number of frames an obvious and predictable
// function of the frame shift and signal length, which makes many
// segmentation-related questions simpler.
//
// Because integer division in C++ rounds toward zero, we add (half the
// frame-shift minus epsilon) before dividing, to have the effect of
// rounding towards the closest integer.
int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
if (flush) return num_frames;
// note: 'end' always means the last plus one, i.e. one past the last.
int64_t end_sample_of_last_frame =
FirstSampleOfFrame(num_frames - 1, opts) + frame_length;
// the following code is optimized more for clarity than efficiency.
// If flush == false, we can't output frames that extend past the end
// of the signal.
while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
num_frames--;
end_sample_of_last_frame -= frame_shift;
}
return num_frames;
}
}
void ExtractWindow(int64_t sample_offset, const std::vector<float> &wave,
int32_t f, const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
std::vector<float> *window,
float *log_energy_pre_window /*= nullptr*/) {
KNF_CHECK(sample_offset >= 0 && wave.size() != 0);
int32_t frame_length = opts.WindowSize();
int32_t frame_length_padded = opts.PaddedWindowSize();
int64_t num_samples = sample_offset + wave.size();
int64_t start_sample = FirstSampleOfFrame(f, opts);
int64_t end_sample = start_sample + frame_length;
if (opts.snip_edges) {
KNF_CHECK(start_sample >= sample_offset && end_sample <= num_samples);
} else {
KNF_CHECK(sample_offset == 0 || start_sample >= sample_offset);
}
if (window->size() != frame_length_padded) {
window->resize(frame_length_padded);
}
// wave_start and wave_end are start and end indexes into 'wave', for the
// piece of wave that we're trying to extract.
int32_t wave_start = int32_t(start_sample - sample_offset);
int32_t wave_end = wave_start + frame_length;
if (wave_start >= 0 && wave_end <= wave.size()) {
// the normal case-- no edge effects to consider.
std::copy(wave.begin() + wave_start,
wave.begin() + wave_start + frame_length, window->data());
} else {
// Deal with any end effects by reflection, if needed. This code will only
// be reached for about two frames per utterance, so we don't concern
// ourselves excessively with efficiency.
int32_t wave_dim = wave.size();
for (int32_t s = 0; s < frame_length; ++s) {
int32_t s_in_wave = s + wave_start;
while (s_in_wave < 0 || s_in_wave >= wave_dim) {
// reflect around the beginning or end of the wave.
// e.g. -1 -> 0, -2 -> 1.
// dim -> dim - 1, dim + 1 -> dim - 2.
// the code supports repeated reflections, although this
// would only be needed in pathological cases.
if (s_in_wave < 0)
s_in_wave = -s_in_wave - 1;
else
s_in_wave = 2 * wave_dim - 1 - s_in_wave;
}
(*window)[s] = wave[s_in_wave];
}
}
ProcessWindow(opts, window_function, window->data(), log_energy_pre_window);
}
static void RemoveDcOffset(float *d, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += d[i];
}
float mean = sum / n;
for (int32_t i = 0; i != n; ++i) {
d[i] -= mean;
}
}
float InnerProduct(const float *a, const float *b, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += a[i] * b[i];
}
return sum;
}
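// Note: Preemphasize() below applies d[i] -= coeff * d[i - 1] in place, with
// d[0] -= coeff * d[0] as the boundary case; e.g., with the default
// coeff == 0.97 a constant signal of 1.0 becomes 0.03 everywhere.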
static void Preemphasize(float *d, int32_t n, float preemph_coeff) {
if (preemph_coeff == 0.0) {
return;
}
KNF_CHECK(preemph_coeff >= 0.0 && preemph_coeff <= 1.0);
for (int32_t i = n - 1; i > 0; --i) {
d[i] -= preemph_coeff * d[i - 1];
}
d[0] -= preemph_coeff * d[0];
}
void ProcessWindow(const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function, float *window,
float *log_energy_pre_window /*= nullptr*/) {
int32_t frame_length = opts.WindowSize();
// TODO(fangjun): Remove dither
KNF_CHECK_EQ(opts.dither, 0);
if (opts.remove_dc_offset) {
RemoveDcOffset(window, frame_length);
}
if (log_energy_pre_window != NULL) {
float energy = std::max<float>(InnerProduct(window, window, frame_length),
std::numeric_limits<float>::epsilon());
*log_energy_pre_window = std::log(energy);
}
if (opts.preemph_coeff != 0.0) {
Preemphasize(window, frame_length, opts.preemph_coeff);
}
window_function.Apply(window);
}
} // namespace knf
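Worked numbers (illustrative) for the default 16 kHz options (400-sample / 25 ms window, 160-sample / 10 ms shift) applied to a 1-second waveform of 16000 samples:
// snip_edges == true : 1 + (16000 - 400) / 160 = 1 + 97 = 98 frames
// snip_edges == false: (16000 + 160 / 2) / 160 = 100 frames (flush == true)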

@ -0,0 +1,178 @@
// kaldi-native-fbank/csrc/feature-window.h
//
// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
// This file is copied/modified from kaldi/src/feat/feature-window.h
#ifndef KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
#define KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/log.h"
namespace knf {
inline int32_t RoundUpToNearestPowerOfTwo(int32_t n) {
// copied from kaldi/src/base/kaldi-math.cc
KNF_CHECK_GT(n, 0);
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
return n + 1;
}
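// e.g., 1 -> 1, 3 -> 4, 400 -> 512; a power of two is returned unchanged.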
struct FrameExtractionOptions {
float samp_freq = 16000;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
float dither = 1.0f; // Amount of dithering, 0.0 means no dither.
float preemph_coeff = 0.97f; // Preemphasis coefficient.
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window
// May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman"
// "povey" is a window I made to be similar to Hamming but to go to zero at
// the edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) I just don't think the
// Hamming window makes sense as a windowing function.
bool round_to_power_of_two = true;
float blackman_coeff = 0.42f;
bool snip_edges = true;
// bool allow_downsample = false;
// bool allow_upsample = false;
// Used for streaming feature extraction. It indicates the number
// of feature frames to keep in the recycling vector. -1 means to
// keep all feature frames.
int32_t max_feature_vectors = -1;
int32_t WindowShift() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
}
int32_t WindowSize() const {
return static_cast<int32_t>(samp_freq * 0.001f * frame_length_ms);
}
int32_t PaddedWindowSize() const {
return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
: WindowSize());
}
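// With the defaults above (samp_freq = 16000, frame_shift_ms = 10,
// frame_length_ms = 25, round_to_power_of_two = true):
//   WindowShift() == 160, WindowSize() == 400, PaddedWindowSize() == 512.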
std::string ToString() const {
std::ostringstream os;
#define KNF_PRINT(x) os << #x << ": " << x << "\n"
KNF_PRINT(samp_freq);
KNF_PRINT(frame_shift_ms);
KNF_PRINT(frame_length_ms);
KNF_PRINT(dither);
KNF_PRINT(preemph_coeff);
KNF_PRINT(remove_dc_offset);
KNF_PRINT(window_type);
KNF_PRINT(round_to_power_of_two);
KNF_PRINT(blackman_coeff);
KNF_PRINT(snip_edges);
// KNF_PRINT(allow_downsample);
// KNF_PRINT(allow_upsample);
KNF_PRINT(max_feature_vectors);
#undef KNF_PRINT
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
class FeatureWindowFunction {
public:
FeatureWindowFunction() = default;
explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
/**
* @param wave Pointer to a 1-D array of shape [window_size].
* It is modified in-place: wave[i] = wave[i] * window_[i].
*/
void Apply(float *wave) const;
private:
std::vector<float> window_; // of size opts.WindowSize()
};
int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts);
/**
This function returns the number of frames that we can extract from a wave
file with the given number of samples in it (assumed to have the same
sampling rate as specified in 'opts').
@param [in] num_samples The number of samples in the wave file.
@param [in] opts The frame-extraction options class
@param [in] flush True if we are asserting that this number of samples
is 'all there is', false if we are expecting more data to possibly come in.
This only makes a difference to the answer
if opts.snip_edges == false. For offline feature extraction you always want
flush == true. In an online-decoding context, once you know (or decide) that
no more data is coming in, you'd call it with flush == true at the end to
flush out any remaining data.
*/
int32_t NumFrames(int64_t num_samples, const FrameExtractionOptions &opts,
bool flush = true);
/*
ExtractWindow() extracts a windowed frame of waveform (possibly with a
power-of-two, padded size, depending on the config), including all the
processing done by ProcessWindow().
@param [in] sample_offset If 'wave' is not the entire waveform, but
part of it to the left has been discarded, then the
number of samples prior to 'wave' that we have
already discarded. Set this to zero if you are
processing the entire waveform in one piece, or
if you get 'no matching function' compilation
errors when updating the code.
@param [in] wave The waveform
@param [in] f The frame index to be extracted, with
0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true)
@param [in] opts The options class to be used
@param [in] window_function The windowing function, as derived from the
options class.
@param [out] window The windowed, possibly-padded waveform to be
extracted. Will be resized as needed.
@param [out] log_energy_pre_window If non-NULL, the log-energy of
the signal prior to pre-emphasis and multiplying by
the windowing function will be written to here.
*/
void ExtractWindow(int64_t sample_offset, const std::vector<float> &wave,
int32_t f, const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function,
std::vector<float> *window,
float *log_energy_pre_window = nullptr);
/**
This function does all the windowing steps after actually
extracting the windowed signal: depending on the
configuration, it does dithering, dc offset removal,
preemphasis, and multiplication by the windowing function.
@param [in] opts The options class to be used
@param [in] window_function The windowing function-- should have
been initialized using 'opts'.
@param [in,out] window A vector of size opts.WindowSize(). Note:
it will typically be a sub-vector of a larger vector of size
opts.PaddedWindowSize(), with the remaining samples zero,
as the FFT code is more efficient if it operates on data with
power-of-two size.
@param [out] log_energy_pre_window If non-NULL, then after dithering and
DC offset removal, this function will write to this pointer the log of
the total energy (i.e. sum-squared) of the frame.
*/
void ProcessWindow(const FrameExtractionOptions &opts,
const FeatureWindowFunction &window_function, float *window,
float *log_energy_pre_window = nullptr);
// Compute the inner product of two vectors
float InnerProduct(const float *a, const float *b, int32_t n);
} // namespace knf
#endif // KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_

File diff suppressed because it is too large

@ -0,0 +1,143 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Stack trace related stuff is from kaldi.
* Refer to
* https://github.com/kaldi-asr/kaldi/blob/master/src/base/kaldi-error.cc
*/
#include "kaldi-native-fbank/csrc/log.h"
#ifdef KNF_HAVE_EXECINFO_H
#include <execinfo.h> // To get stack trace in error messages.
#ifdef KNF_HAVE_CXXABI_H
#include <cxxabi.h> // For name demangling.
// Useful to decode the stack trace, but only used if we have execinfo.h
#endif // KNF_HAVE_CXXABI_H
#endif // KNF_HAVE_EXECINFO_H
#include <stdlib.h>
#include <ctime>
#include <iomanip>
#include <string>
namespace knf {
std::string GetDateTimeStr() {
std::ostringstream os;
std::time_t t = std::time(nullptr);
std::tm tm = *std::localtime(&t);
os << std::put_time(&tm, "%F %T"); // yyyy-mm-dd hh:mm:ss
return os.str();
}
static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin,
std::size_t *end) {
// Find the first '_' with leading ' ' or '('.
*begin = std::string::npos;
for (std::size_t i = 1; i < trace_name.size(); ++i) {
if (trace_name[i] != '_') {
continue;
}
if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') {
*begin = i;
break;
}
}
if (*begin == std::string::npos) {
return false;
}
*end = trace_name.find_first_of(" +", *begin);
return *end != std::string::npos;
}
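// For example, given the Linux-style trace line
//   "./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]"
// *begin points at the '_' that follows '(' and *end at the '+', so the
// extracted symbol is "_ZN5kaldi13UnitTestErrorEv".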
#ifdef KNF_HAVE_EXECINFO_H
static std::string Demangle(const std::string &trace_name) {
#ifndef KNF_HAVE_CXXABI_H
return trace_name;
#else // KNF_HAVE_CXXABI_H
// Try demangle the symbol. We are trying to support the following formats
// produced by different platforms:
//
// Linux:
// ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]
//
// Mac:
// 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813
//
// We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and
// demangle it into a readable name like kaldi::UnitTestError.
std::size_t begin, end;
if (!LocateSymbolRange(trace_name, &begin, &end)) {
return trace_name;
}
std::string symbol = trace_name.substr(begin, end - begin);
int status;
char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status);
if (status == 0 && demangled_name != nullptr) {
symbol = demangled_name;
free(demangled_name);
}
return trace_name.substr(0, begin) + symbol +
trace_name.substr(end, std::string::npos);
#endif // KNF_HAVE_CXXABI_H
}
#endif // KNF_HAVE_EXECINFO_H
std::string GetStackTrace() {
std::string ans;
#ifdef KNF_HAVE_EXECINFO_H
constexpr const std::size_t kMaxTraceSize = 50;
constexpr const std::size_t kMaxTracePrint = 50; // Must be even.
// Buffer for the trace.
void *trace[kMaxTraceSize];
// Get the trace.
std::size_t size = backtrace(trace, kMaxTraceSize);
// Get the trace symbols.
char **trace_symbol = backtrace_symbols(trace, size);
if (trace_symbol == nullptr)
return ans;
// Compose a human-readable backtrace string.
ans += "[ Stack-Trace: ]\n";
if (size <= kMaxTracePrint) {
for (std::size_t i = 0; i < size; ++i) {
ans += Demangle(trace_symbol[i]) + "\n";
}
} else { // Print out first+last (e.g.) 5.
for (std::size_t i = 0; i < kMaxTracePrint / 2; ++i) {
ans += Demangle(trace_symbol[i]) + "\n";
}
ans += ".\n.\n.\n";
for (std::size_t i = size - kMaxTracePrint / 2; i < size; ++i) {
ans += Demangle(trace_symbol[i]) + "\n";
}
if (size == kMaxTraceSize)
ans += ".\n.\n.\n"; // Stack was too long, probably a bug.
}
// We must free the array of pointers allocated by backtrace_symbols(),
// but not the strings themselves.
free(trace_symbol);
#endif // KNF_HAVE_EXECINFO_H
return ans;
}
} // namespace knf

@ -0,0 +1,347 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// The content in this file is copied/modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/log.h
#ifndef KALDI_NATIVE_FBANK_CSRC_LOG_H_
#define KALDI_NATIVE_FBANK_CSRC_LOG_H_
#include <stdio.h>
#include <cstdint>
#include <cstdlib>
#include <mutex>  // NOLINT
#include <sstream>
#include <stdexcept>
#include <string>
namespace knf {
#if defined(NDEBUG)
constexpr bool kDisableDebug = true;
#else
constexpr bool kDisableDebug = false;
#endif
enum class LogLevel {
kTrace = 0,
kDebug = 1,
kInfo = 2,
kWarning = 3,
kError = 4,
kFatal = 5, // print message and abort the program
};
// They are used in KNF_LOG(xxx), so their names
// do not follow the google c++ code style
//
// You can use them in the following way:
//
// KNF_LOG(TRACE) << "some message";
// KNF_LOG(DEBUG) << "some message";
#ifndef _MSC_VER
constexpr LogLevel TRACE = LogLevel::kTrace;
constexpr LogLevel DEBUG = LogLevel::kDebug;
constexpr LogLevel INFO = LogLevel::kInfo;
constexpr LogLevel WARNING = LogLevel::kWarning;
constexpr LogLevel ERROR = LogLevel::kError;
constexpr LogLevel FATAL = LogLevel::kFatal;
#else
#define TRACE LogLevel::kTrace
#define DEBUG LogLevel::kDebug
#define INFO LogLevel::kInfo
#define WARNING LogLevel::kWarning
#define ERROR LogLevel::kError
#define FATAL LogLevel::kFatal
#endif
std::string GetStackTrace();
/* Return the current log level.
If the current log level is TRACE, then all logged messages are printed out.
If the current log level is DEBUG, log messages with "TRACE" level are not
shown and all other levels are printed out.
Similarly, if the current log level is INFO, log messages with "TRACE" and
"DEBUG" are not shown and all other levels are printed out.
If it is FATAL, then only FATAL messages are shown.
*/
inline LogLevel GetCurrentLogLevel() {
static LogLevel log_level = INFO;
static std::once_flag init_flag;
std::call_once(init_flag, []() {
const char *env_log_level = std::getenv("KNF_LOG_LEVEL");
if (env_log_level == nullptr) return;
std::string s = env_log_level;
if (s == "TRACE")
log_level = TRACE;
else if (s == "DEBUG")
log_level = DEBUG;
else if (s == "INFO")
log_level = INFO;
else if (s == "WARNING")
log_level = WARNING;
else if (s == "ERROR")
log_level = ERROR;
else if (s == "FATAL")
log_level = FATAL;
else
fprintf(stderr,
"Unknown KNF_LOG_LEVEL: %s"
"\nSupported values are: "
"TRACE, DEBUG, INFO, WARNING, ERROR, FATAL",
s.c_str());
});
return log_level;
}
inline bool EnableAbort() {
static std::once_flag init_flag;
static bool enable_abort = false;
std::call_once(init_flag, []() {
enable_abort = (std::getenv("KNF_ABORT") != nullptr);
});
return enable_abort;
}
class Logger {
public:
Logger(const char *filename, const char *func_name, uint32_t line_num,
LogLevel level)
: filename_(filename),
func_name_(func_name),
line_num_(line_num),
level_(level) {
cur_level_ = GetCurrentLogLevel();
fprintf(stderr, "here\n");
switch (level) {
case TRACE:
if (cur_level_ <= TRACE) fprintf(stderr, "[T] ");
break;
case DEBUG:
if (cur_level_ <= DEBUG) fprintf(stderr, "[D] ");
break;
case INFO:
if (cur_level_ <= INFO) fprintf(stderr, "[I] ");
break;
case WARNING:
if (cur_level_ <= WARNING) fprintf(stderr, "[W] ");
break;
case ERROR:
if (cur_level_ <= ERROR) fprintf(stderr, "[E] ");
break;
case FATAL:
if (cur_level_ <= FATAL) fprintf(stderr, "[F] ");
break;
}
if (cur_level_ <= level_) {
fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name);
}
}
~Logger() noexcept(false) {
static constexpr const char *kErrMsg = R"(
Some bad things happened. Please read the above error messages and stack
trace. If you are using Python, the following command may be helpful:
gdb --args python /path/to/your/code.py
(You can use `gdb` to debug the code. Please consider compiling
a debug version of KNF.)
If you are unable to fix it, please open an issue at:
https://github.com/csukuangfj/kaldi-native-fbank/issues/new
)";
fprintf(stderr, "\n");
if (level_ == FATAL) {
std::string stack_trace = GetStackTrace();
if (!stack_trace.empty()) {
fprintf(stderr, "\n\n%s\n", stack_trace.c_str());
}
fflush(nullptr);
#ifndef __ANDROID_API__
if (EnableAbort()) {
// NOTE: abort() will terminate the program immediately without
// printing the Python stack backtrace.
abort();
}
throw std::runtime_error(kErrMsg);
#else
abort();
#endif
}
}
const Logger &operator<<(bool b) const {
if (cur_level_ <= level_) {
fprintf(stderr, b ? "true" : "false");
}
return *this;
}
const Logger &operator<<(int8_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
return *this;
}
const Logger &operator<<(const char *s) const {
if (cur_level_ <= level_) fprintf(stderr, "%s", s);
return *this;
}
const Logger &operator<<(int32_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%d", i);
return *this;
}
const Logger &operator<<(uint32_t i) const {
if (cur_level_ <= level_) fprintf(stderr, "%u", i);
return *this;
}
const Logger &operator<<(uint64_t i) const {
if (cur_level_ <= level_)
fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT
return *this;
}
const Logger &operator<<(int64_t i) const {
if (cur_level_ <= level_)
fprintf(stderr, "%lli", (long long int)i); // NOLINT
return *this;
}
const Logger &operator<<(float f) const {
if (cur_level_ <= level_) fprintf(stderr, "%f", f);
return *this;
}
const Logger &operator<<(double d) const {
if (cur_level_ <= level_) fprintf(stderr, "%f", d);
return *this;
}
template <typename T>
const Logger &operator<<(const T &t) const {
// require T overloads operator<<
std::ostringstream os;
os << t;
return *this << os.str().c_str();
}
// overload to fix a compile error: `stringstream << nullptr` is ambiguous
const Logger &operator<<(const std::nullptr_t &null) const {
if (cur_level_ <= level_) *this << "(null)";
return *this;
}
private:
const char *filename_;
const char *func_name_;
uint32_t line_num_;
LogLevel level_;
LogLevel cur_level_;
};
class Voidifier {
public:
void operator&(const Logger &) const {}
};
} // namespace knf
#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \
defined(__PRETTY_FUNCTION__)
// for clang and GCC
#define KNF_FUNC __PRETTY_FUNCTION__
#else
// for other compilers
#define KNF_FUNC __func__
#endif
#define KNF_STATIC_ASSERT(x) static_assert(x, "")
#define KNF_CHECK(x) \
(x) ? (void)0 \
: ::knf::Voidifier() & \
::knf::Logger(__FILE__, KNF_FUNC, __LINE__, ::knf::FATAL) \
<< "Check failed: " << #x << " "
// WARNING: x and y may be evaluated multiple times, but this happens only
// when the check fails. Since the program aborts if it fails, we don't think
// the extra evaluation of x and y matters.
//
// CAUTION: we recommend the following use case:
//
// auto x = Foo();
// auto y = Bar();
// KNF_CHECK_EQ(x, y) << "Some message";
//
// And please avoid
//
// KNF_CHECK_EQ(Foo(), Bar());
//
// if `Foo()` or `Bar()` causes some side effects, e.g., changing some
// local static variables or global variables.
#define _KNF_CHECK_OP(x, y, op) \
((x)op(y)) ? (void)0 \
: ::knf::Voidifier() & \
::knf::Logger(__FILE__, KNF_FUNC, __LINE__, ::knf::FATAL) \
<< "Check failed: " << #x << " " << #op << " " << #y \
<< " (" << (x) << " vs. " << (y) << ") "
#define KNF_CHECK_EQ(x, y) _KNF_CHECK_OP(x, y, ==)
#define KNF_CHECK_NE(x, y) _KNF_CHECK_OP(x, y, !=)
#define KNF_CHECK_LT(x, y) _KNF_CHECK_OP(x, y, <)
#define KNF_CHECK_LE(x, y) _KNF_CHECK_OP(x, y, <=)
#define KNF_CHECK_GT(x, y) _KNF_CHECK_OP(x, y, >)
#define KNF_CHECK_GE(x, y) _KNF_CHECK_OP(x, y, >=)
#define KNF_LOG(x) ::knf::Logger(__FILE__, KNF_FUNC, __LINE__, ::knf::x)
// ------------------------------------------------------------
// For debug check
// ------------------------------------------------------------
// If you define the macro "-D NDEBUG" while compiling kaldi-native-fbank,
the following macros are in fact empty and do nothing.
#define KNF_DCHECK(x) ::knf::kDisableDebug ? (void)0 : KNF_CHECK(x)
#define KNF_DCHECK_EQ(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_EQ(x, y)
#define KNF_DCHECK_NE(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_NE(x, y)
#define KNF_DCHECK_LT(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_LT(x, y)
#define KNF_DCHECK_LE(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_LE(x, y)
#define KNF_DCHECK_GT(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_GT(x, y)
#define KNF_DCHECK_GE(x, y) ::knf::kDisableDebug ? (void)0 : KNF_CHECK_GE(x, y)
#define KNF_DLOG(x) \
::knf::kDisableDebug ? (void)0 : ::knf::Voidifier() & KNF_LOG(x)
#endif // KALDI_NATIVE_FBANK_CSRC_LOG_H_
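A quick usage sketch (illustrative, not part of this commit) for the macros above; the logging threshold comes from the KNF_LOG_LEVEL environment variable (e.g., KNF_LOG_LEVEL=DEBUG ./your-binary), and setting KNF_ABORT makes a failed check call abort() instead of throwing std::runtime_error:
#include "kaldi-native-fbank/csrc/log.h"
void Demo(int32_t num_bins) {
  KNF_LOG(INFO) << "num_bins: " << num_bins;    // shown when level <= INFO
  KNF_DLOG(DEBUG) << "debug build only";        // compiled out under -DNDEBUG
  KNF_CHECK_GE(num_bins, 3) << "too few bins";  // FATAL + stack trace on failure
}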

@ -0,0 +1,256 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/mel-computations.cc
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include <algorithm>
#include <sstream>
#include "kaldi-native-fbank/csrc/feature-window.h"
namespace knf {
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
os << opts.ToString();
return os;
}
float MelBanks::VtlnWarpFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,
float low_freq, // upper+lower frequency cutoffs in mel computation
float high_freq, float vtln_warp_factor, float freq) {
/// This computes a VTLN warping function that is not the same as HTK's one,
/// but has similar inputs (this function has the advantage of never producing
/// empty bins).
/// This function computes a warp function F(freq), defined between low_freq
/// and high_freq inclusive, with the following properties:
/// F(low_freq) == low_freq
/// F(high_freq) == high_freq
/// The function is continuous and piecewise linear with two inflection
/// points.
/// The lower inflection point (measured in terms of the unwarped
/// frequency) is at frequency l, determined as described below.
/// The higher inflection point is at a frequency h, determined as
/// described below.
/// If l <= f <= h, then F(f) = f/vtln_warp_factor.
/// If the higher inflection point (measured in terms of the unwarped
/// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
/// Since (by the last point) F(h) == h/vtln_warp_factor, then
/// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
/// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
/// = vtln_high_cutoff * min(1, vtln_warp_factor).
/// If the lower inflection point (measured in terms of the unwarped
/// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
/// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
/// = vtln_low_cutoff * max(1, vtln_warp_factor)
if (freq < low_freq || freq > high_freq)
return freq; // in case this gets called
// for out-of-range frequencies, just return the freq.
KNF_CHECK_GT(vtln_low_cutoff, low_freq);
KNF_CHECK_LT(vtln_high_cutoff, high_freq);
float one = 1.0f;
float l = vtln_low_cutoff * std::max(one, vtln_warp_factor);
float h = vtln_high_cutoff * std::min(one, vtln_warp_factor);
float scale = 1.0f / vtln_warp_factor;
float Fl = scale * l; // F(l);
float Fh = scale * h; // F(h);
KNF_CHECK(l > low_freq && h < high_freq);
// slope of left part of the 3-piece linear function
float scale_left = (Fl - low_freq) / (l - low_freq);
// [slope of center part is just "scale"]
// slope of right part of the 3-piece linear function
float scale_right = (high_freq - Fh) / (high_freq - h);
if (freq < l) {
return low_freq + scale_left * (freq - low_freq);
} else if (freq < h) {
return scale * freq;
} else { // freq >= h
return high_freq + scale_right * (freq - high_freq);
}
}
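// Worked example (illustrative): at 16 kHz (Nyquist 8000) with vtln_low = 100,
// vtln_high = -500 (i.e., 7500 after adding the Nyquist frequency) and
// vtln_warp_factor = 1.1: l = 100 * 1.1 = 110, h = 7500 * 1.0 = 7500, and
// every freq in [110, 7500] maps to freq / 1.1; the two outer linear pieces
// keep F(low_freq) == low_freq and F(high_freq) == high_freq.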
float MelBanks::VtlnWarpMelFreq(
float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN.
float vtln_high_cutoff,
float low_freq, // upper+lower frequency cutoffs in mel computation
float high_freq, float vtln_warp_factor, float mel_freq) {
return MelScale(VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, low_freq,
high_freq, vtln_warp_factor,
InverseMelScale(mel_freq)));
}
MelBanks::MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts,
float vtln_warp_factor)
: htk_mode_(opts.htk_mode) {
int32_t num_bins = opts.num_bins;
if (num_bins < 3) KNF_LOG(FATAL) << "Must have at least 3 mel bins";
float sample_freq = frame_opts.samp_freq;
int32_t window_length_padded = frame_opts.PaddedWindowSize();
KNF_CHECK_EQ(window_length_padded % 2, 0);
int32_t num_fft_bins = window_length_padded / 2;
float nyquist = 0.5f * sample_freq;
float low_freq = opts.low_freq, high_freq;
if (opts.high_freq > 0.0f)
high_freq = opts.high_freq;
else
high_freq = nyquist + opts.high_freq;
if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f ||
high_freq > nyquist || high_freq <= low_freq) {
KNF_LOG(FATAL) << "Bad values in options: low-freq " << low_freq
<< " and high-freq " << high_freq << " vs. nyquist "
<< nyquist;
}
float fft_bin_width = sample_freq / window_length_padded;
// fft-bin width [think of it as Nyquist-freq / half-window-length]
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
debug_ = opts.debug_mel;
// divide by num_bins+1 in next line because of end-effects where the bins
// spread out to the sides.
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high;
if (vtln_high < 0.0f) {
vtln_high += nyquist;
}
if (vtln_warp_factor != 1.0f &&
(vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq ||
vtln_high <= 0.0f || vtln_high >= high_freq || vtln_high <= vtln_low)) {
KNF_LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low
<< " and vtln-high " << vtln_high << ", versus "
<< "low-freq " << low_freq << " and high-freq " << high_freq;
}
bins_.resize(num_bins);
center_freqs_.resize(num_bins);
for (int32_t bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
if (vtln_warp_factor != 1.0f) {
left_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
vtln_warp_factor, left_mel);
center_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
vtln_warp_factor, center_mel);
right_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq,
vtln_warp_factor, right_mel);
}
center_freqs_[bin] = InverseMelScale(center_mel);
// this_bin will be a vector of coefficients that is only
// nonzero where this mel bin is active.
std::vector<float> this_bin(num_fft_bins);
int32_t first_index = -1, last_index = -1;
for (int32_t i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
}
}
KNF_CHECK(first_index != -1 && last_index >= first_index &&
"You may have set num_mel_bins too large.");
bins_[bin].first = first_index;
int32_t size = last_index + 1 - first_index;
bins_[bin].second.insert(bins_[bin].second.end(),
this_bin.begin() + first_index,
this_bin.begin() + first_index + size);
// Replicate a bug in HTK, for testing purposes.
if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) {
bins_[bin].second[0] = 0.0;
}
} // for (int32_t bin = 0; bin < num_bins; ++bin) {
if (debug_) {
std::ostringstream os;
for (size_t i = 0; i < bins_.size(); i++) {
os << "bin " << i << ", offset = " << bins_[i].first << ", vec = ";
for (auto k : bins_[i].second) os << k << ", ";
os << "\n";
}
KNF_LOG(INFO) << os.str();
}
}
// "power_spectrum" contains fft energies.
void MelBanks::Compute(const float *power_spectrum,
float *mel_energies_out) const {
int32_t num_bins = bins_.size();
for (int32_t i = 0; i < num_bins; i++) {
int32_t offset = bins_[i].first;
const auto &v = bins_[i].second;
float energy = 0;
for (int32_t k = 0; k != v.size(); ++k) {
energy += v[k] * power_spectrum[k + offset];
}
// HTK-like flooring, for testing purposes (we prefer dither)
if (htk_mode_ && energy < 1.0) {
energy = 1.0;
}
mel_energies_out[i] = energy;
// The following assert was added due to a problem with OpenBlas that
// we had at one point (it was a bug in that library). Just to detect
// it early.
KNF_CHECK_EQ(energy, energy); // check that energy is not nan
}
if (debug_) {
fprintf(stderr, "MEL BANKS:\n");
for (int32_t i = 0; i < num_bins; i++)
fprintf(stderr, " %f", mel_energies_out[i]);
fprintf(stderr, "\n");
}
}
} // namespace knf
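A worked number (illustrative) behind the constants in MelScale()/InverseMelScale() declared in the header: 1127 is 1000 / ln(1 + 1000/700), so 1000 Hz sits at almost exactly 1000 mel:
// MelScale(1000.0f)        == 1127 * logf(1 + 1000.0f / 700)   ~= 1000.0 mel
// InverseMelScale(1000.0f) == 700 * (expf(1000.0f / 1127) - 1) ~= 1000.0 Hz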

@ -0,0 +1,115 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This file is copied/modified from kaldi/src/feat/mel-computations.h
#ifndef KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
#define KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
#include <cmath>
#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-window.h"
namespace knf {
struct MelBanksOptions {
int32_t num_bins = 25; // e.g. 25; number of triangular bins
float low_freq = 20; // e.g. 20; lower frequency cutoff
// an upper frequency cutoff; 0 -> no cutoff, negative
// -> added to the Nyquist frequency to get the cutoff.
float high_freq = 0;
float vtln_low = 100; // vtln lower cutoff of warping function.
// vtln upper cutoff of warping function: if negative, added
// to the Nyquist frequency to get the cutoff.
float vtln_high = -500;
bool debug_mel = false;
// htk_mode is a "hidden" config, it does not show up on command line.
// Enables more exact compatibility with HTK, for testing purposes. Affects
// mel-energy flooring and reproduces a bug in HTK.
bool htk_mode = false;
std::string ToString() const {
std::ostringstream os;
os << "num_bins: " << num_bins << "\n";
os << "low_freq: " << low_freq << "\n";
os << "high_freq: " << high_freq << "\n";
os << "vtln_low: " << vtln_low << "\n";
os << "vtln_high: " << vtln_high << "\n";
os << "debug_mel: " << debug_mel << "\n";
os << "htk_mode: " << htk_mode << "\n";
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
class MelBanks {
public:
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static float VtlnWarpFreq(
float vtln_low_cutoff,
float vtln_high_cutoff, // discontinuities in warp func
float low_freq,
float high_freq, // upper+lower frequency cutoffs in
// the mel computation
float vtln_warp_factor, float freq);
static float VtlnWarpMelFreq(float vtln_low_cutoff, float vtln_high_cutoff,
float low_freq, float high_freq,
float vtln_warp_factor, float mel_freq);
// TODO(fangjun): Remove vtln_warp_factor
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts, float vtln_warp_factor);
/// Compute Mel energies (note: not log energies).
/// At input, "fft_energies" contains the FFT energies (not log).
///
/// @param fft_energies 1-D array of size num_fft_bins/2+1
/// @param mel_energies_out 1-D array of size num_mel_bins
void Compute(const float *fft_energies, float *mel_energies_out) const;
int32_t NumBins() const { return bins_.size(); }
private:
// center frequencies of bins, numbered from 0 ... num_bins-1.
// Needed by GetCenterFreqs().
std::vector<float> center_freqs_;
// the "bins_" vector is a vector, one for each bin, of a pair:
// (the first nonzero fft-bin), (the vector of weights).
std::vector<std::pair<int32_t, std::vector<float>>> bins_;
// TODO(fangjun): Remove debug_ and htk_mode_
bool debug_;
bool htk_mode_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_

@ -0,0 +1,66 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kaldi-native-fbank/csrc/rfft.h"
#include <algorithm>
#include <cmath>
#include <vector>
#include "kaldi-native-fbank/csrc/log.h"
// see fftsg.c
#ifdef __cplusplus
extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w);
#else
void rdft(int n, int isgn, double *a, int *ip, double *w);
#endif
namespace knf {
class Rfft::RfftImpl {
public:
explicit RfftImpl(int32_t n) : n_(n), ip_(2 + std::sqrt(n / 2)), w_(n / 2) {
KNF_CHECK_EQ(n & (n - 1), 0);
}
void Compute(float *in_out) {
std::vector<double> d(in_out, in_out + n_);
Compute(d.data());
std::copy(d.begin(), d.end(), in_out);
}
void Compute(double *in_out) {
// 1 means forward fft
rdft(n_, 1, in_out, ip_.data(), w_.data());
}
private:
int32_t n_;
std::vector<int32_t> ip_;
std::vector<double> w_;
};
Rfft::Rfft(int32_t n) : impl_(std::make_unique<RfftImpl>(n)) {}
Rfft::~Rfft() = default;
void Rfft::Compute(float *in_out) { impl_->Compute(in_out); }
void Rfft::Compute(double *in_out) { impl_->Compute(in_out); }
} // namespace knf
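A tiny end-to-end check (illustrative, not part of this commit) of the packed output layout documented in rfft.h, using a constant (DC-only) 4-point signal:
#include "kaldi-native-fbank/csrc/rfft.h"
void RfftDemo() {
  knf::Rfft rfft(4);
  float d[4] = {1, 1, 1, 1};  // constant signal: only the DC bin is nonzero
  rfft.Compute(d);
  // d now holds {R[0], R[n/2], R[1], I[1]} == {4, 0, 0, 0}
}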

@ -0,0 +1,56 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef KALDI_NATIVE_FBANK_CSRC_RFFT_H_
#define KALDI_NATIVE_FBANK_CSRC_RFFT_H_
#include <cstdint>
#include <memory>
namespace knf {
// n-point Real discrete Fourier transform
// where n is a power of 2. n >= 2
//
// R[k] = sum_j=0^n-1 in[j]*cos(2*pi*j*k/n), 0<=k<=n/2
// I[k] = sum_j=0^n-1 in[j]*sin(2*pi*j*k/n), 0<k<n/2
class Rfft {
public:
// @param n Number of fft bins. it should be a power of 2.
explicit Rfft(int32_t n);
~Rfft();
/** @param in_out A 1-D array of size n.
* On return:
* in_out[0] = R[0]
* in_out[1] = R[n/2]
* for 1 < k < n/2,
* in_out[2*k] = R[k]
* in_out[2*k+1] = I[k]
*
*/
void Compute(float *in_out);
void Compute(double *in_out);
private:
class RfftImpl;
std::unique_ptr<RfftImpl> impl_;
};
} // namespace knf
#endif // KALDI_NATIVE_FBANK_CSRC_RFFT_H_

@ -1,111 +0,0 @@
# check out thirdparty/kaldi/base/kaldi-types.h
# compile kaldi without openfst
add_definitions("-DCOMPILE_WITHOUT_OPENFST")
if ((NOT EXISTS ${CMAKE_CURRENT_LIST_DIR}/base))
file(COPY ../../../../speechx/speechx/kaldi/base DESTINATION ${CMAKE_CURRENT_LIST_DIR})
file(COPY ../../../../speechx/speechx/kaldi/feat DESTINATION ${CMAKE_CURRENT_LIST_DIR})
file(COPY ../../../../speechx/speechx/kaldi/matrix DESTINATION ${CMAKE_CURRENT_LIST_DIR})
file(COPY ../../../../speechx/speechx/kaldi/util DESTINATION ${CMAKE_CURRENT_LIST_DIR})
endif()
# kaldi-base
add_library(kaldi-base STATIC
base/io-funcs.cc
base/kaldi-error.cc
base/kaldi-math.cc
base/kaldi-utils.cc
base/timer.cc
)
target_include_directories(kaldi-base PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
# kaldi-matrix
add_library(kaldi-matrix STATIC
matrix/compressed-matrix.cc
matrix/matrix-functions.cc
matrix/kaldi-matrix.cc
matrix/kaldi-vector.cc
matrix/optimization.cc
matrix/packed-matrix.cc
matrix/qr.cc
matrix/sparse-matrix.cc
matrix/sp-matrix.cc
matrix/srfft.cc
matrix/tp-matrix.cc
)
target_include_directories(kaldi-matrix PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
if (NOT MSVC)
target_link_libraries(kaldi-matrix PUBLIC kaldi-base libopenblas)
else()
target_link_libraries(kaldi-matrix PUBLIC kaldi-base openblas)
endif()
# kaldi-util
add_library(kaldi-util STATIC
util/kaldi-holder.cc
util/kaldi-io.cc
util/kaldi-semaphore.cc
util/kaldi-table.cc
util/kaldi-thread.cc
util/parse-options.cc
util/simple-io-funcs.cc
util/simple-options.cc
util/text-utils.cc
)
target_include_directories(kaldi-util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(kaldi-util PUBLIC kaldi-base kaldi-matrix)
# kaldi-feat-common
add_library(kaldi-feat-common STATIC
feat/cmvn.cc
feat/feature-functions.cc
feat/feature-window.cc
feat/mel-computations.cc
feat/pitch-functions.cc
feat/resample.cc
feat/signal.cc
feat/wave-reader.cc
)
target_include_directories(kaldi-feat-common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
# kaldi-mfcc
add_library(kaldi-mfcc STATIC
feat/feature-mfcc.cc
)
target_include_directories(kaldi-mfcc PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(kaldi-mfcc PUBLIC kaldi-feat-common)
# kaldi-fbank
add_library(kaldi-fbank STATIC
feat/feature-fbank.cc
)
target_include_directories(kaldi-fbank PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(kaldi-fbank PUBLIC kaldi-feat-common)
set(KALDI_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-base.a
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-matrix.a
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-util.a
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-feat-common.a
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-mfcc.a
${CMAKE_CURRENT_BINARY_DIR}/libkaldi-fbank.a
)
add_library(libkaldi INTERFACE)
add_dependencies(libkaldi kaldi-base kaldi-matrix kaldi-util kaldi-feat-common kaldi-mfcc kaldi-fbank)
target_include_directories(libkaldi INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
if (APPLE)
target_link_libraries(libkaldi INTERFACE ${KALDI_LIBRARIES} libopenblas ${GFORTRAN_LIBRARIES_DIR}/libgfortran.a ${GFORTRAN_LIBRARIES_DIR}/libquadmath.a ${GFORTRAN_LIBRARIES_DIR}/libgcc_s.1.1.dylib)
elseif (MSVC)
target_link_libraries(libkaldi INTERFACE kaldi-base kaldi-matrix kaldi-util kaldi-feat-common kaldi-mfcc kaldi-fbank openblas)
else()
target_link_libraries(libkaldi INTERFACE -Wl,--start-group -Wl,--whole-archive ${KALDI_LIBRARIES} libopenblas.a gfortran -Wl,--no-whole-archive -Wl,--end-group)
endif()
target_compile_definitions(libkaldi INTERFACE "-DCOMPILE_WITHOUT_OPENFST")

@ -40,19 +40,13 @@ COMMITID = 'none'
base = [
"kaldiio",
"librosa==0.8.1",
"scipy>=1.0.0",
"soundfile~=0.10",
"colorlog",
"pathos == 0.2.8",
"pathos",
"pybind11",
"parameterized",
"tqdm",
"scikit-learn"
]
requirements = {
"install":
base,
"install": base,
"develop": [
"sox",
"soxbindings",
@ -60,6 +54,7 @@ requirements = {
],
}
def check_call(cmd: str, shell=False, executable=None):
try:
sp.check_call(
@ -92,6 +87,7 @@ def check_output(cmd: Union[str, List[str], Tuple[str]], shell=False):
file=sys.stderr)
return out_bytes.strip().decode('utf8')
def _run_cmd(cmd):
try:
return subprocess.check_output(
@ -100,6 +96,7 @@ def _run_cmd(cmd):
except Exception:
return None
@contextlib.contextmanager
def pushd(new_dir):
old_dir = os.getcwd()
@ -109,22 +106,26 @@ def pushd(new_dir):
os.chdir(old_dir)
print(old_dir)
def read(*names, **kwargs):
with io.open(
os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")) as fp:
return fp.read()
def _remove(files: str):
for f in files:
f.unlink()
################################# Install ##################################
def _post_install(install_lib_dir):
pass
class DevelopCommand(develop):
def run(self):
develop.run(self)
@ -142,7 +143,7 @@ class TestCommand(test):
# Run nose ensuring that argv simulates running nosetests directly
import nose
nose.run_exit(argv=['nosetests', '-w', 'tests'])
def run_benchmark(self):
for benchmark_item in glob.glob('tests/benchmark/*py'):
os.system(f'pytest {benchmark_item}')
@ -188,6 +189,7 @@ def _make_version_file(version, sha):
with open(version_path, "a") as f:
f.write(f"__version__ = '{version}'\n")
def _rm_version():
file_ = ROOT_DIR / "paddleaudio" / "__init__.py"
with open(file_, "r") as f:
@ -235,8 +237,8 @@ def main():
if platform.system() != 'Windows' and platform.system() != 'Linux':
lib_package_data = {'paddleaudio': ['lib/libgcc_s.1.1.dylib']}
if platform.system() == 'Linux':
lib_package_data = {'paddleaudio': ['lib/lib*']}
#if platform.system() == 'Linux':
# lib_package_data = {'paddleaudio': ['lib/lib*']}
setup_info = dict(
# Metadata
@ -254,8 +256,7 @@ def main():
python_requires='>=3.7',
install_requires=requirements["install"],
extras_require={
'develop':
requirements["develop"],
'develop': requirements["develop"],
#'test': ["nose", "torchaudio==0.10.2", "pytest-benchmark", "librosa=0.8.1", "parameterized", "paddlepaddle"],
},
cmdclass={
@ -267,7 +268,7 @@ def main():
},
# Package info
packages=find_packages(include=('paddleaudio*')),
packages=find_packages(include=['paddleaudio*']),
package_data=lib_package_data,
ext_modules=setup_helpers.get_ext_modules(),
zip_safe=True,
@ -284,11 +285,11 @@ def main():
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
],
)
], )
setup(**setup_info)
_rm_version()
if __name__ == '__main__':
main()

@ -18,139 +18,7 @@ Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/62'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62'
DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz'
MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/aidatatang_200zh",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aidatatang_200_zh_transcript.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
        # remove whitespace to get character-level text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'corpus/', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
if not fname.endswith('.wav'):
continue
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
utt2spk = Path(audio_path).parent.name
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'utt2spk': str(utt2spk),
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, subset)
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'corpus')
for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)):
for sub in dirlist:
print(f"unpack dir {sub}...")
for folder, _, filelist in sorted(
os.walk(os.path.join(subfolder, sub))):
for ftar in filelist:
unpack(os.path.join(folder, ftar), folder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix,
subset='aidatatang_200zh')
print("Data download and manifest prepare done!")
from paddlespeech.dataset.aidatatang_200zh import aidatatang_200zh_main
if __name__ == '__main__':
aidatatang_200zh_main()
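For reference, each line that `create_manifest()` writes is a standalone JSON object. A minimal sketch of one manifest line (the utterance id, speaker, path, duration, and text below are hypothetical examples, not taken from the corpus):
```python
import json

# Hypothetical example of one manifest line, using the fields emitted by
# create_manifest() above; ids and paths are made up for illustration.
entry = {
    'utt': 'T0055G0013S0001',
    'utt2spk': 'G0013',
    'feat': '/data/aidatatang_200zh/corpus/train/G0013/T0055G0013S0001.wav',
    'feat_shape': (4.23, ),  # duration in seconds
    'text': '今天天气怎么样',
}
print(json.dumps(entry, ensure_ascii=False))
```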

@ -1,3 +0,0 @@
# [Aishell1](http://openslr.elda.org/33/)
This open-source Mandarin speech corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is part of AISHELL-ASR0009, whose utterances cover 11 domains, including smart home, autonomous driving, and industrial production. All recordings were made in a quiet indoor environment using 3 devices simultaneously: a high-fidelity microphone (44.1 kHz, 16-bit), an Android-system mobile phone (16 kHz, 16-bit), and an iOS-system mobile phone (16 kHz, 16-bit). The high-fidelity audio was downsampled to 16 kHz to build AISHELL-ASR0009-OS1. 400 speakers from different accent areas in China were invited to take part in the recording. Through professional speech annotation and strict quality inspection, the manual transcription accuracy is above 95%. The corpus is divided into training, development, and test sets. (The database is free for academic research; commercial use without permission is not allowed.)

@ -18,143 +18,7 @@ Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://openslr.elda.org/resources/33'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
RESOURCE_URL = URL_ROOT + '/resource_aishell.tgz'
MD5_RESOURCE = '957d480a0fcac85fc18e550756f624e5'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/Aishell",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aishell_transcript_v0.8.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
        # remove whitespace to get character-level text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'wav', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
                # skip audio files that have no transcription
if audio_id not in transcript_dict:
continue
utt2spk = Path(audio_path).parent.name
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'utt2spk': str(utt2spk),
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
manifest_dir = os.path.dirname(manifest_path_prefix)
meta_path = os.path.join(manifest_dir, dtype) + '.meta'
with open(meta_path, 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
if manifest_path:
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
prepare_dataset(
url=RESOURCE_URL,
md5sum=MD5_RESOURCE,
target_dir=args.target_dir,
manifest_path=None)
print("Data download and manifest prepare done!")
from paddlespeech.dataset.aishell import aishell_main
if __name__ == '__main__':
aishell_main()

@ -28,8 +28,8 @@ from multiprocessing.pool import Pool
import distutils.util
import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
URL_ROOT = "http://openslr.elda.org/resources/12"
#URL_ROOT = "https://openslr.magicdatatech.com/resources/12"

@ -27,8 +27,8 @@ from multiprocessing.pool import Pool
import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
URL_ROOT = "http://openslr.elda.org/resources/31"
URL_TRAIN_CLEAN = URL_ROOT + "/train-clean-5.tar.gz"

@ -29,8 +29,8 @@ import os
import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@ -29,8 +29,8 @@ import os
import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unzip
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@ -0,0 +1,13 @@
# [TAL_CSASR](https://ai.100tal.com/dataset/)
This dataset contains TAL English-class audio with mixed Chinese and English speech. Each recording has a single speaker, and the dataset covers more than 100 speakers in total (63.36 GB of files). The data includes both intra-sentence and inter-sentence code-switching; the ratio of Chinese characters to English words is 13:1.
- Total data: 587H (train_set: 555.9H, dev_set: 8H, test_set: 23.6H)
- Sample rate: 16000
- Sample bit: 16
- Recording device: microphone
- Speaker number: 200+
- Recording time: 2019
- Data format: audio: .wav; text: .txt
- Audio duration: 1-60s
- Data type: audio of English teachers' teaching

@ -0,0 +1,116 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare TALCS ASR datasets.
create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import io
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
TRAIN_SET = os.path.join(args.target_dir, "train_set")
DEV_SET = os.path.join(args.target_dir, "dev_set")
TEST_SET = os.path.join(args.target_dir, "test_set")
manifest_train_path = os.path.join(args.manifest_prefix, "manifest.train.raw")
manifest_dev_path = os.path.join(args.manifest_prefix, "manifest.dev.raw")
manifest_test_path = os.path.join(args.manifest_prefix, "manifest.test.raw")
def create_manifest(data_dir, manifest_path):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_char = 0.0
total_num = 0
wav_dir = os.path.join(data_dir, 'wav')
text_filepath = os.path.join(data_dir, 'label.txt')
for subfolder, _, filelist in sorted(os.walk(wav_dir)):
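        # NOTE: label.txt is re-read for every subfolder that os.walk() yields;
        # this assumes all the wav files sit in a single directory under wav/.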
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
nchars = len(segments[1:])
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.wav'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
utt = os.path.splitext(os.path.basename(audio_filepath))[0]
utt2spk = '-'.join(utt.split('-')[:2])
json_lines.append(
json.dumps({
'utt': utt,
'utt2spk': utt2spk,
'feat': audio_filepath,
'feat_shape': (duration, ), # second
'text': text,
}))
total_sec += duration
total_char += nchars
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
    # manifest paths look like "manifest.train.raw"; take the middle component
    subset = os.path.splitext(os.path.splitext(manifest_path)[0])[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_char} char", file=f)
print(f"{total_char / total_sec} char/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
create_manifest(TRAIN_SET, manifest_train_path)
create_manifest(DEV_SET, manifest_dev_path)
create_manifest(TEST_SET, manifest_test_path)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()

@ -27,8 +27,8 @@ from pathlib import Path
import soundfile
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@ -28,7 +28,7 @@ from pathlib import Path
import soundfile
from paddlespeech.dataset.download import unzip
URL_ROOT = ""
MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d"

@ -31,9 +31,9 @@ from pathlib import Path
import soundfile
from paddlespeech.dataset.download import check_md5sum
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unzip
# all the data will be download in the current data/voxceleb directory default
DATA_HOME = os.path.expanduser('.')

@ -27,9 +27,9 @@ from pathlib import Path
import soundfile
from paddlespeech.dataset.download import check_md5sum
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unzip
# all the data will be download in the current data/voxceleb directory default
DATA_HOME = os.path.expanduser('.')

@ -28,9 +28,9 @@ import subprocess
import soundfile
from paddlespeech.dataset.download import download_multi
from paddlespeech.dataset.download import getfile_insensitive
from paddlespeech.dataset.download import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@ -1,6 +1,6 @@
# TTS Java API Demo Usage Guide
This demo implements text-to-speech on Android; it is easy to use and open to extension, e.g. you can run your own trained models in the demo.
This document mainly describes how to run the TTS demo.
@ -157,8 +157,11 @@ The Android example is built on the Java API; calling the Paddle Lite `Java API` involves the following steps
### Updating the input
**This demo does not include a text frontend module.** Preset texts are selected from a drop-down box and mapped to the corresponding phone_ids in code. **If you need a text frontend module, handle it yourself**; for reference:
- C++ Chinese frontend: [lym0302/paddlespeech_tts_cpp](https://github.com/lym0302/paddlespeech_tts_cpp)
- C++ English g2p: [yazone/g2pE_mobile](https://github.com/yazone/g2pE_mobile)
For `phone_id_map.txt`, see [fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip).
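As a minimal illustration (not part of the demo; the file path and phone sequence below are hypothetical), mapping a phone sequence to ids with a `phone_id_map.txt` of `<phone> <id>` lines could look like this:
```python
# Minimal sketch: load a phone_id_map.txt ("<phone> <id>" per line) and map
# phones to ids; the file path and phone list are hypothetical examples.
def load_phone_id_map(path):
    phone2id = {}
    with open(path, encoding='utf8') as f:
        for line in f:
            phone, idx = line.split()
            phone2id[phone] = int(idx)
    return phone2id

phone2id = load_phone_id_map('phone_id_map.txt')
phones = ['n', 'i3', 'h', 'ao3']
phone_ids = [phone2id.get(p, phone2id.get('<unk>', 1)) for p in phones]
print(phone_ids)
```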
## Updating TTS parameters via the Settings screen

@ -0,0 +1,8 @@
# directories
build/
output/
libs/
models/
# symbolic link
dict

@ -0,0 +1,91 @@
# TTS ARM Linux C++ Demo
Adapted from [demos/TTSAndroid](../TTSAndroid); the models also come from that Android demo.
### Configure build options
Open [config.sh](config.sh) and adjust the options as needed.
A 64-bit build is produced by default; for a 32-bit build, change `ARM_ABI=armv8` to `ARM_ABI=armv7hf`.
### Install dependencies
```bash
# Ubuntu
sudo apt install build-essential cmake pkg-config wget tar unzip

# CentOS
sudo yum groupinstall "Development Tools"
sudo yum install cmake wget tar unzip
```
### Download the Paddle Lite library and model files
The prebuilt binaries use the same Paddle Lite inference library ([Paddle-Lite:68b66fd35](https://github.com/PaddlePaddle/Paddle-Lite/tree/68b66fd356c875c92167d311ad458e6093078449)) and models ([fs2cnn_mbmelgan_cpu_v1.3.0](https://paddlespeech.bj.bcebos.com/demos/TTSAndroid/fs2cnn_mbmelgan_cpu_v1.3.0.tar.gz)) as the Android demo.
Download them with:
```bash
./download.sh
```
### Build the demo
```bash
./build.sh
```
The prebuilt binaries are compatible with Ubuntu 16.04 through 20.04.
If compilation or linking fails, your distribution is incompatible with the prebuilt library; try building the Paddle Lite library manually, following the steps at the end of this document.
### Run
You can change the `--phone2id_path` parameter in `./front.conf` to the `phone_id_map.txt` of your own acoustic model.
```bash
./run.sh
./run.sh --sentence "语音合成测试"
./run.sh --sentence "输出到指定的音频文件" --output_wav ./output/test.wav
./run.sh --help
```
Currently only Chinese synthesis is supported; any English input will crash the program.
If `--output_wav` is not specified, output is written to `./output/tts.wav` by default.
## Building the Paddle Lite library manually
The prebuilt binaries are compatible with Ubuntu 16.04 through 20.04; if your distribution is incompatible with them, you can build from source yourself.
Note that we can only guarantee that [Paddle-Lite:68b66fd35](https://github.com/PaddlePaddle/Paddle-Lite/tree/68b66fd356c875c92167d311ad458e6093078449) is compatible with the models downloaded by `download.sh`.
If you use another version of the Paddle Lite library, you may need to re-export the models with the matching version of the opt tool.
In addition, [Paddle-Lite 2.12](https://github.com/PaddlePaddle/Paddle-Lite/releases/tag/v2.12) is incompatible with TTS: it can neither export nor run TTS models, so use a newer version (e.g. the code on the `develop` branch).
The code on the `develop` branch, however, may be incompatible with the models downloaded by `download.sh`, and the demo may crash at runtime.
### Install Paddle Lite build dependencies
```bash
# Ubuntu
sudo apt install build-essential cmake git python

# CentOS
sudo yum groupinstall "Development Tools"
sudo yum install cmake git python
```
### Build Paddle Lite 68b66fd35
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
git checkout 68b66fd356c875c92167d311ad458e6093078449
./lite/tools/build_linux.sh --with_extra=ON
```
After the build finishes, open the demo's [config.sh](config.sh) and set `PADDLE_LITE_DIR` to the following value (replace `/path/to/` with the actual path):
```
PADDLE_LITE_DIR="/path/to/Paddle-Lite/build.lite.linux.${ARM_ABI}.gcc/inference_lite_lib.armlinux.${ARM_ABI}/cxx"
```

@ -0,0 +1 @@
src/TTSCppFrontend/build-depends.sh

@ -0,0 +1,29 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
BASE_DIR="$PWD"
# load configuration
. ./config.sh
# build
echo "ARM_ABI is ${ARM_ABI}"
echo "PADDLE_LITE_DIR is ${PADDLE_LITE_DIR}"
echo "Build depends..."
./build-depends.sh "$@"
mkdir -p "$BASE_DIR/build"
cd "$BASE_DIR/build"
cmake -DPADDLE_LITE_DIR="${PADDLE_LITE_DIR}" -DARM_ABI="${ARM_ABI}" ../src
if [ "$*" = "" ]; then
make -j$(nproc)
else
make "$@"
fi
echo "make successful!"

@ -0,0 +1,23 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
BASE_DIR="$PWD"
# load configuration
. ./config.sh
# remove dirs
set -x
rm -rf "$OUTPUT_DIR"
rm -rf "$LIBS_DIR"
rm -rf "$MODELS_DIR"
rm -rf "$BASE_DIR/build"
"$BASE_DIR/src/TTSCppFrontend/clean.sh"
# remove the symbolic link
rm "$BASE_DIR/dict"

@ -0,0 +1,15 @@
# configuration
ARM_ABI=armv8
#ARM_ABI=armv7hf
MODELS_DIR="${PWD}/models"
LIBS_DIR="${PWD}/libs"
OUTPUT_DIR="${PWD}/output"
PADDLE_LITE_DIR="${LIBS_DIR}/inference_lite_lib.armlinux.${ARM_ABI}.gcc.with_extra.with_cv/cxx"
#PADDLE_LITE_DIR="/path/to/Paddle-Lite/build.lite.linux.${ARM_ABI}.gcc/inference_lite_lib.armlinux.${ARM_ABI}/cxx"
ACOUSTIC_MODEL_PATH="${MODELS_DIR}/cpu/fastspeech2_csmsc_arm.nb"
VOCODER_PATH="${MODELS_DIR}/cpu/mb_melgan_csmsc_arm.nb"
FRONT_CONF="${PWD}/front.conf"

@ -0,0 +1,70 @@
#!/bin/bash
set -e
cd "$(dirname "$(realpath "$0")")"
BASE_DIR="$PWD"
# load configuration
. ./config.sh
mkdir -p "$LIBS_DIR" "$MODELS_DIR"
download() {
file="$1"
url="$2"
md5="$3"
dir="$4"
cd "$dir"
if [ -f "$file" ] && [ "$(md5sum "$file" | awk '{ print $1 }')" = "$md5" ]; then
echo "File $file (MD5: $md5) has been downloaded."
else
echo "Downloading $file..."
wget -O "$file" "$url"
# MD5 verify
fileMd5="$(md5sum "$file" | awk '{ print $1 }')"
if [ "$fileMd5" == "$md5" ]; then
echo "File $file (MD5: $md5) has been downloaded."
else
echo "MD5 mismatch, file may be corrupt"
echo "$file MD5: $fileMd5, it should be $md5"
fi
fi
echo "Extracting $file..."
echo '-----------------------'
tar -vxf "$file"
echo '======================='
}
########################################
echo "Download models..."
download 'inference_lite_lib.armlinux.armv8.gcc.with_extra.with_cv.tar.gz' \
'https://paddlespeech.bj.bcebos.com/demos/TTSArmLinux/inference_lite_lib.armlinux.armv8.gcc.with_extra.with_cv.tar.gz' \
'39e0c6604f97c70f5d13c573d7e709b9' \
"$LIBS_DIR"
download 'inference_lite_lib.armlinux.armv7hf.gcc.with_extra.with_cv.tar.gz' \
'https://paddlespeech.bj.bcebos.com/demos/TTSArmLinux/inference_lite_lib.armlinux.armv7hf.gcc.with_extra.with_cv.tar.gz' \
'f5ceb509f0b610dafb8379889c5f36f8' \
"$LIBS_DIR"
download 'fs2cnn_mbmelgan_cpu_v1.3.0.tar.gz' \
'https://paddlespeech.bj.bcebos.com/demos/TTSAndroid/fs2cnn_mbmelgan_cpu_v1.3.0.tar.gz' \
'93ef17d44b498aff3bea93e2c5c09a1e' \
"$MODELS_DIR"
echo "Done."
########################################
echo "Download dictionary files..."
ln -s src/TTSCppFrontend/front_demo/dict "$BASE_DIR/"
"$BASE_DIR/src/TTSCppFrontend/download.sh"

@ -0,0 +1,21 @@
# jieba conf
--jieba_dict_path=./dict/jieba/jieba.dict.utf8
--jieba_hmm_path=./dict/jieba/hmm_model.utf8
--jieba_user_dict_path=./dict/jieba/user.dict.utf8
--jieba_idf_path=./dict/jieba/idf.utf8
--jieba_stop_word_path=./dict/jieba/stop_words.utf8
# dict conf fastspeech2_0.4
--separate_tone=false
--word2phone_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict
--phone2id_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
--tone2id_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict
# dict conf speedyspeech_0.5
#--separate_tone=true
#--word2phone_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/word2phone.dict
#--phone2id_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt
#--tone2id_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt
# dictionary for traditional-to-simplified conversion
--trand2simpd_path=./dict/tranditional_to_simplified/trand2simp.txt

@ -0,0 +1,19 @@
#!/bin/bash
set -e
cd "$(dirname "$(realpath "$0")")"
# load configuration
. ./config.sh
# create dir
mkdir -p "$OUTPUT_DIR"
# run
set -x
./build/paddlespeech_tts_demo \
--front_conf "$FRONT_CONF" \
--acoustic_model "$ACOUSTIC_MODEL_PATH" \
--vocoder "$VOCODER_PATH" \
"$@"
# end

@ -0,0 +1,80 @@
cmake_minimum_required(VERSION 3.10)
project(paddlespeech_tts_demo)
########## Global Options ##########
option(WITH_FRONT_DEMO "Build front demo" OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(ABSL_PROPAGATE_CXX_STD ON)
########## ARM Options ##########
set(CMAKE_SYSTEM_NAME Linux)
if(ARM_ABI STREQUAL "armv8")
set(CMAKE_SYSTEM_PROCESSOR aarch64)
#set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc")
#set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++")
elseif(ARM_ABI STREQUAL "armv7hf")
set(CMAKE_SYSTEM_PROCESSOR arm)
#set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc")
#set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++")
else()
message(FATAL_ERROR "Unknown arch abi ${ARM_ABI}, only support armv8 and armv7hf.")
return()
endif()
########## Paddle Lite Options ##########
message(STATUS "TARGET ARCH ABI: ${ARM_ABI}")
message(STATUS "PADDLE LITE DIR: ${PADDLE_LITE_DIR}")
include_directories(${PADDLE_LITE_DIR}/include)
link_directories(${PADDLE_LITE_DIR}/libs/${ARM_ABI})
link_directories(${PADDLE_LITE_DIR}/lib)
if(ARM_ABI STREQUAL "armv8")
set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}")
elseif(ARM_ABI STREQUAL "armv7hf")
set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" )
endif()
########## Dependencies ##########
find_package(OpenMP REQUIRED)
if(OpenMP_FOUND OR OpenMP_CXX_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
message(STATUS "Found OpenMP ${OpenMP_VERSION} ${OpenMP_CXX_VERSION}")
message(STATUS "OpenMP C flags: ${OpenMP_C_FLAGS}")
message(STATUS "OpenMP CXX flags: ${OpenMP_CXX_FLAGS}")
message(STATUS "OpenMP OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message(STATUS "OpenMP OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
else()
message(FATAL_ERROR "Could not found OpenMP!")
return()
endif()
############### tts cpp frontend ###############
add_subdirectory(TTSCppFrontend)
include_directories(
TTSCppFrontend/src
third-party/build/src/cppjieba/include
third-party/build/src/limonp/include
)
############### paddlespeech_tts_demo ###############
add_executable(paddlespeech_tts_demo main.cc)
target_link_libraries(paddlespeech_tts_demo paddle_light_api_shared paddlespeech_tts_front)

@ -0,0 +1,320 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "paddle_api.h"
using namespace paddle::lite_api;
class PredictorInterface {
public:
virtual ~PredictorInterface() = 0;
virtual bool Init(const std::string &AcousticModelPath,
const std::string &VocoderPath,
PowerMode cpuPowerMode,
int cpuThreadNum,
                      // The WAV sample rate must match the model output;
                      // if playback speed or pitch sounds wrong, adjust it.
                      // Common rates: 16000, 24000, 32000, 44100, 48000, 96000
uint32_t wavSampleRate) = 0;
virtual std::shared_ptr<PaddlePredictor> LoadModel(
const std::string &modelPath,
int cpuThreadNum,
PowerMode cpuPowerMode) = 0;
virtual void ReleaseModel() = 0;
virtual bool RunModel(const std::vector<int64_t> &phones) = 0;
virtual std::unique_ptr<const Tensor> GetAcousticModelOutput(
const std::vector<int64_t> &phones) = 0;
virtual std::unique_ptr<const Tensor> GetVocoderOutput(
std::unique_ptr<const Tensor> &&amOutput) = 0;
virtual void VocoderOutputToWav(
std::unique_ptr<const Tensor> &&vocOutput) = 0;
virtual void SaveFloatWav(float *floatWav, int64_t size) = 0;
virtual bool IsLoaded() = 0;
virtual float GetInferenceTime() = 0;
virtual int GetWavSize() = 0;
    // Get the WAV duration in milliseconds
    virtual float GetWavDuration() = 0;
    // Get the RTF (synthesis time / audio duration)
    virtual float GetRTF() = 0;
virtual void ReleaseWav() = 0;
virtual bool WriteWavToFile(const std::string &wavPath) = 0;
};
PredictorInterface::~PredictorInterface() {}
// WavDataType: the WAV sample data type.
// Switch between int16_t and float to produce
// 16-bit PCM or 32-bit IEEE float WAV output.
template <typename WavDataType>
class Predictor : public PredictorInterface {
public:
bool Init(const std::string &AcousticModelPath,
const std::string &VocoderPath,
PowerMode cpuPowerMode,
int cpuThreadNum,
              // The WAV sample rate must match the model output;
              // if playback speed or pitch sounds wrong, adjust it.
              // Common rates: 16000, 24000, 32000, 44100, 48000, 96000
uint32_t wavSampleRate) override {
// Release model if exists
ReleaseModel();
acoustic_model_predictor_ =
LoadModel(AcousticModelPath, cpuThreadNum, cpuPowerMode);
if (acoustic_model_predictor_ == nullptr) {
return false;
}
vocoder_predictor_ = LoadModel(VocoderPath, cpuThreadNum, cpuPowerMode);
if (vocoder_predictor_ == nullptr) {
return false;
}
wav_sample_rate_ = wavSampleRate;
return true;
}
virtual ~Predictor() {
ReleaseModel();
ReleaseWav();
}
std::shared_ptr<PaddlePredictor> LoadModel(
const std::string &modelPath,
int cpuThreadNum,
PowerMode cpuPowerMode) override {
if (modelPath.empty()) {
return nullptr;
}
        // Set up the MobileConfig
MobileConfig config;
config.set_model_from_file(modelPath);
config.set_threads(cpuThreadNum);
config.set_power_mode(cpuPowerMode);
return CreatePaddlePredictor<MobileConfig>(config);
}
void ReleaseModel() override {
acoustic_model_predictor_ = nullptr;
vocoder_predictor_ = nullptr;
}
bool RunModel(const std::vector<int64_t> &phones) override {
if (!IsLoaded()) {
return false;
}
        // Start timing
        auto start = std::chrono::system_clock::now();
        // Run inference
        VocoderOutputToWav(GetVocoderOutput(GetAcousticModelOutput(phones)));
        // Stop timing
        auto end = std::chrono::system_clock::now();
        // Compute the elapsed time
        std::chrono::duration<float> duration = end - start;
        inference_time_ = duration.count() * 1000;  // in milliseconds
return true;
}
std::unique_ptr<const Tensor> GetAcousticModelOutput(
const std::vector<int64_t> &phones) override {
auto phones_handle = acoustic_model_predictor_->GetInput(0);
phones_handle->Resize({static_cast<int64_t>(phones.size())});
phones_handle->CopyFromCpu(phones.data());
acoustic_model_predictor_->Run();
        // Get the output tensor
        auto am_output_handle = acoustic_model_predictor_->GetOutput(0);
        // Print the output tensor's shape
std::cout << "Acoustic Model Output shape: ";
auto shape = am_output_handle->shape();
for (auto s : shape) {
std::cout << s << ", ";
}
std::cout << std::endl;
return am_output_handle;
}
std::unique_ptr<const Tensor> GetVocoderOutput(
std::unique_ptr<const Tensor> &&amOutput) override {
auto mel_handle = vocoder_predictor_->GetInput(0);
// [?, 80]
auto dims = amOutput->shape();
mel_handle->Resize(dims);
auto am_output_data = amOutput->mutable_data<float>();
mel_handle->CopyFromCpu(am_output_data);
vocoder_predictor_->Run();
        // Get the output tensor
        auto voc_output_handle = vocoder_predictor_->GetOutput(0);
        // Print the output tensor's shape
std::cout << "Vocoder Output shape: ";
auto shape = voc_output_handle->shape();
for (auto s : shape) {
std::cout << s << ", ";
}
std::cout << std::endl;
return voc_output_handle;
}
void VocoderOutputToWav(
std::unique_ptr<const Tensor> &&vocOutput) override {
        // Read the output tensor's size and data
int64_t output_size = 1;
for (auto dim : vocOutput->shape()) {
output_size *= dim;
}
auto output_data = vocOutput->mutable_data<float>();
SaveFloatWav(output_data, output_size);
}
void SaveFloatWav(float *floatWav, int64_t size) override;
bool IsLoaded() override {
return acoustic_model_predictor_ != nullptr &&
vocoder_predictor_ != nullptr;
}
float GetInferenceTime() override { return inference_time_; }
const std::vector<WavDataType> &GetWav() { return wav_; }
int GetWavSize() override { return wav_.size() * sizeof(WavDataType); }
    // Get the WAV duration in milliseconds
float GetWavDuration() override {
return static_cast<float>(GetWavSize()) / sizeof(WavDataType) /
static_cast<float>(wav_sample_rate_) * 1000;
}
    // Get the RTF (synthesis time / audio duration)
float GetRTF() override { return GetInferenceTime() / GetWavDuration(); }
void ReleaseWav() override { wav_.clear(); }
bool WriteWavToFile(const std::string &wavPath) override {
std::ofstream fout(wavPath, std::ios::binary);
if (!fout.is_open()) {
return false;
}
        // Write the WAV header
WavHeader header;
header.audio_format = GetWavAudioFormat();
header.data_size = GetWavSize();
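        // RIFF chunk size: total file size minus the 8 bytes occupied by the
        // "RIFF" id and this size field itself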
header.size = sizeof(header) - 8 + header.data_size;
header.sample_rate = wav_sample_rate_;
header.byte_rate = header.sample_rate * header.num_channels *
header.bits_per_sample / 8;
header.block_align = header.num_channels * header.bits_per_sample / 8;
fout.write(reinterpret_cast<const char *>(&header), sizeof(header));
        // Write the WAV sample data
fout.write(reinterpret_cast<const char *>(wav_.data()),
header.data_size);
fout.close();
return true;
}
protected:
struct WavHeader {
        // RIFF header
char riff[4] = {'R', 'I', 'F', 'F'};
uint32_t size = 0;
char wave[4] = {'W', 'A', 'V', 'E'};
        // fmt chunk
char fmt[4] = {'f', 'm', 't', ' '};
uint32_t fmt_size = 16;
uint16_t audio_format = 0;
uint16_t num_channels = 1;
uint32_t sample_rate = 0;
uint32_t byte_rate = 0;
uint16_t block_align = 0;
uint16_t bits_per_sample = sizeof(WavDataType) * 8;
        // data chunk
char data[4] = {'d', 'a', 't', 'a'};
uint32_t data_size = 0;
};
enum WavAudioFormat {
        WAV_FORMAT_16BIT_PCM = 1,   // 16-bit PCM format
        WAV_FORMAT_32BIT_FLOAT = 3  // 32-bit IEEE float format
};
protected:
    // The return value is selected by template specialization on WavDataType
inline uint16_t GetWavAudioFormat();
inline float Abs(float number) { return (number < 0) ? -number : number; }
protected:
float inference_time_ = 0;
uint32_t wav_sample_rate_ = 0;
std::vector<WavDataType> wav_;
std::shared_ptr<PaddlePredictor> acoustic_model_predictor_ = nullptr;
std::shared_ptr<PaddlePredictor> vocoder_predictor_ = nullptr;
};
template <>
uint16_t Predictor<int16_t>::GetWavAudioFormat() {
return Predictor::WAV_FORMAT_16BIT_PCM;
}
template <>
uint16_t Predictor<float>::GetWavAudioFormat() {
return Predictor::WAV_FORMAT_32BIT_FLOAT;
}
// Save WAV as 16-bit PCM
template <>
void Predictor<int16_t>::SaveFloatWav(float *floatWav, int64_t size) {
wav_.resize(size);
    // Find the peak sample value (the 0.01 floor avoids dividing by zero
    // on near-silent output)
    float maxSample = 0.01;
for (int64_t i = 0; i < size; i++) {
float sample = Abs(floatWav[i]);
if (sample > maxSample) {
maxSample = sample;
}
}
    // Scale samples into the int16 range
for (int64_t i = 0; i < size; i++) {
wav_[i] = floatWav[i] * 32767.0f / maxSample;
}
}
// Save WAV as 32-bit IEEE float
template <>
void Predictor<float>::SaveFloatWav(float *floatWav, int64_t size) {
wav_.resize(size);
std::copy_n(floatWav, size, wav_.data());
}
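For illustration only (not part of the demo), a rough NumPy equivalent of the 16-bit conversion above, assuming `numpy` is available:
```python
import numpy as np

# Rough NumPy equivalent of Predictor<int16_t>::SaveFloatWav above: scale so
# the loudest sample maps to int16 full scale; the 0.01 floor mirrors the
# C++ code and avoids dividing by ~zero on silent output.
def float_to_int16(wav: np.ndarray) -> np.ndarray:
    peak = max(float(np.abs(wav).max()), 0.01)
    return (wav * 32767.0 / peak).astype(np.int16)
```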

@ -0,0 +1 @@
../../TTSCppFrontend/

@ -0,0 +1,162 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <front/front_interface.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <paddle_api.h>
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include "Predictor.hpp"
using namespace paddle::lite_api;
DEFINE_string(
sentence,
"你好,欢迎使用语音合成服务",
"Text to be synthesized (Chinese only. English will crash the program.)");
DEFINE_string(front_conf, "./front.conf", "Front configuration file");
DEFINE_string(acoustic_model,
"./models/cpu/fastspeech2_csmsc_arm.nb",
"Acoustic model .nb file");
DEFINE_string(vocoder,
              "./models/cpu/mb_melgan_csmsc_arm.nb",
              "Vocoder .nb file");
DEFINE_string(output_wav, "./output/tts.wav", "Output WAV file");
DEFINE_string(wav_bit_depth,
"16",
"WAV bit depth, 16 (16-bit PCM) or 32 (32-bit IEEE float)");
DEFINE_string(wav_sample_rate,
"24000",
"WAV sample rate, should match the output of the vocoder");
DEFINE_string(cpu_thread, "1", "CPU thread numbers");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
PredictorInterface *predictor;
if (FLAGS_wav_bit_depth == "16") {
predictor = new Predictor<int16_t>();
} else if (FLAGS_wav_bit_depth == "32") {
predictor = new Predictor<float>();
} else {
LOG(ERROR) << "Unsupported WAV bit depth: " << FLAGS_wav_bit_depth;
return -1;
}
    /////////////////////////// Frontend: text to phonemes ///////////////////////////
    // Instantiate the text frontend engine
ppspeech::FrontEngineInterface *front_inst = nullptr;
front_inst = new ppspeech::FrontEngineInterface(FLAGS_front_conf);
if ((!front_inst) || (front_inst->init())) {
LOG(ERROR) << "Creater tts engine failed!";
if (front_inst != nullptr) {
delete front_inst;
}
front_inst = nullptr;
return -1;
}
std::wstring ws_sentence = ppspeech::utf8string2wstring(FLAGS_sentence);
    // Convert traditional Chinese to simplified
std::wstring sentence_simp;
front_inst->Trand2Simp(ws_sentence, &sentence_simp);
ws_sentence = sentence_simp;
std::string s_sentence;
std::vector<std::wstring> sentence_part;
std::vector<int> phoneids = {};
std::vector<int> toneids = {};
    // Split the text into sentences by punctuation
    LOG(INFO) << "Start to segment sentences by punctuation";
    front_inst->SplitByPunc(ws_sentence, &sentence_part);
    LOG(INFO) << "Sentences segmented by punctuation successfully";
    // Get the phoneme ids of each sentence after splitting
LOG(INFO)
<< "Start to get the phoneme and tone id sequence of each sentence";
for (int i = 0; i < sentence_part.size(); i++) {
LOG(INFO) << "Raw sentence is: "
<< ppspeech::wstring2utf8string(sentence_part[i]);
front_inst->SentenceNormalize(&sentence_part[i]);
s_sentence = ppspeech::wstring2utf8string(sentence_part[i]);
LOG(INFO) << "After normalization sentence is: " << s_sentence;
if (0 != front_inst->GetSentenceIds(s_sentence, &phoneids, &toneids)) {
LOG(ERROR) << "TTS inst get sentence phoneids and toneids failed";
return -1;
}
}
LOG(INFO) << "The phoneids of the sentence is: "
<< limonp::Join(phoneids.begin(), phoneids.end(), " ");
LOG(INFO) << "The toneids of the sentence is: "
<< limonp::Join(toneids.begin(), toneids.end(), " ");
LOG(INFO) << "Get the phoneme id sequence of each sentence successfully";
    /////////////////////////// Backend: phonemes to audio ///////////////////////////
    // The WAV sample rate must match the model output;
    // if playback speed or pitch sounds wrong, adjust it.
    // Common rates: 16000, 24000, 32000, 44100, 48000, 96000
    const uint32_t wavSampleRate = std::stoul(FLAGS_wav_sample_rate);
    // Number of CPU threads
    const int cpuThreadNum = std::stol(FLAGS_cpu_thread);
    // CPU power mode
    const PowerMode cpuPowerMode = PowerMode::LITE_POWER_HIGH;
if (!predictor->Init(FLAGS_acoustic_model,
FLAGS_vocoder,
cpuPowerMode,
cpuThreadNum,
wavSampleRate)) {
LOG(ERROR) << "predictor init failed" << std::endl;
return -1;
}
std::vector<int64_t> phones(phoneids.size());
std::transform(phoneids.begin(), phoneids.end(), phones.begin(), [](int x) {
return static_cast<int64_t>(x);
});
if (!predictor->RunModel(phones)) {
LOG(ERROR) << "predictor run model failed" << std::endl;
return -1;
}
LOG(INFO) << "Inference time: " << predictor->GetInferenceTime() << " ms, "
<< "WAV size (without header): " << predictor->GetWavSize()
<< " bytes, "
<< "WAV duration: " << predictor->GetWavDuration() << " ms, "
<< "RTF: " << predictor->GetRTF() << std::endl;
if (!predictor->WriteWavToFile(FLAGS_output_wav)) {
LOG(ERROR) << "write wav file failed" << std::endl;
return -1;
}
delete predictor;
return 0;
}

@ -0,0 +1 @@
TTSCppFrontend/third-party

@ -0,0 +1,2 @@
build/
dict/

@ -0,0 +1,63 @@
cmake_minimum_required(VERSION 3.10)
project(paddlespeech_tts_cpp)
########## Global Options ##########
option(WITH_FRONT_DEMO "Build front demo" ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(ABSL_PROPAGATE_CXX_STD ON)
########## Dependencies ##########
set(ENV{PKG_CONFIG_PATH} "${CMAKE_SOURCE_DIR}/third-party/build/lib/pkgconfig:${CMAKE_SOURCE_DIR}/third-party/build/lib64/pkgconfig")
find_package(PkgConfig REQUIRED)
# It is hard to load xxx-config.cmake in a custom location, so use pkgconfig instead.
pkg_check_modules(ABSL REQUIRED absl_strings IMPORTED_TARGET)
pkg_check_modules(GFLAGS REQUIRED gflags IMPORTED_TARGET)
pkg_check_modules(GLOG REQUIRED libglog IMPORTED_TARGET)
# load header-only libraries
include_directories(
${CMAKE_SOURCE_DIR}/third-party/build/src/cppjieba/include
${CMAKE_SOURCE_DIR}/third-party/build/src/limonp/include
)
find_package(Threads REQUIRED)
########## paddlespeech_tts_front ##########
include_directories(src)
file(GLOB FRONT_SOURCES
./src/base/*.cpp
./src/front/*.cpp
)
add_library(paddlespeech_tts_front STATIC ${FRONT_SOURCES})
target_link_libraries(
paddlespeech_tts_front
PUBLIC
PkgConfig::GFLAGS
PkgConfig::GLOG
PkgConfig::ABSL
Threads::Threads
)
########## tts_front_demo ##########
if (WITH_FRONT_DEMO)
file(GLOB FRONT_DEMO_SOURCES front_demo/*.cpp)
add_executable(tts_front_demo ${FRONT_DEMO_SOURCES})
target_include_directories(tts_front_demo PRIVATE ./front_demo)
target_link_libraries(tts_front_demo PRIVATE paddlespeech_tts_front)
endif (WITH_FRONT_DEMO)

@ -0,0 +1,56 @@
# PaddleSpeech TTS CPP Frontend
A TTS frontend that implements text-to-phoneme conversion.
Currently it only supports Chinese; any English word will crash the demo.
## Install Build Tools
```bash
# Ubuntu
sudo apt install build-essential cmake pkg-config
# CentOS
sudo yum groupinstall "Development Tools"
sudo yum install cmake
```
If your CMake version is too old, you can download a precompiled newer version from https://cmake.org/download/
## Build
```bash
# Build with all CPU cores
./build.sh
# Build with 1 core
./build.sh -j1
```
Dependent libraries will be automatically downloaded to the `third-party/build` folder.
If the download speed is too slow, you can open [third-party/CMakeLists.txt](third-party/CMakeLists.txt) and modify `GIT_REPOSITORY` URLs.
## Download dictionary files
```bash
./download.sh
```
## Run
You can change `--phone2id_path` in `./front_demo/front.conf` to the `phone_id_map.txt` of your own acoustic model.
```bash
./run_front_demo.sh
./run_front_demo.sh --help
./run_front_demo.sh --sentence "这是语音合成服务的文本前端,用于将文本转换为音素序号数组。"
./run_front_demo.sh --front_conf ./front_demo/front.conf --sentence "你还需要一个语音合成后端才能将其转换为实际的声音。"
```
## Clean
```bash
./clean.sh
```
The folders `front_demo/dict`, `build` and `third-party/build` will be deleted.

@ -0,0 +1,20 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
cd ./third-party
mkdir -p build
cd build
cmake ..
if [ "$*" = "" ]; then
make -j$(nproc)
else
make "$@"
fi
echo "Done."

@ -0,0 +1,21 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
echo "************* Download & Build Dependencies *************"
./build-depends.sh "$@"
echo "************* Build Front Lib and Demo *************"
mkdir -p ./build
cd ./build
cmake ..
if [ "$*" = "" ]; then
make -j$(nproc)
else
make "$@"
fi
echo "Done."

@ -0,0 +1,10 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
rm -rf "./front_demo/dict"
rm -rf "./build"
rm -rf "./third-party/build"
echo "Done."

@ -0,0 +1,62 @@
#!/bin/bash
set -e
cd "$(dirname "$(realpath "$0")")"
download() {
file="$1"
url="$2"
md5="$3"
dir="$4"
cd "$dir"
if [ -f "$file" ] && [ "$(md5sum "$file" | awk '{ print $1 }')" = "$md5" ]; then
echo "File $file (MD5: $md5) has been downloaded."
else
echo "Downloading $file..."
wget -O "$file" "$url"
# MD5 verify
fileMd5="$(md5sum "$file" | awk '{ print $1 }')"
if [ "$fileMd5" == "$md5" ]; then
echo "File $file (MD5: $md5) has been downloaded."
else
echo "MD5 mismatch, file may be corrupt"
echo "$file MD5: $fileMd5, it should be $md5"
fi
fi
echo "Extracting $file..."
echo '-----------------------'
tar -vxf "$file"
echo '======================='
}
########################################
DIST_DIR="$PWD/front_demo/dict"
mkdir -p "$DIST_DIR"
download 'fastspeech2_nosil_baker_ckpt_0.4.tar.gz' \
'https://paddlespeech.bj.bcebos.com/t2s/text_frontend/fastspeech2_nosil_baker_ckpt_0.4.tar.gz' \
'7bf1bab1737375fa123c413eb429c573' \
"$DIST_DIR"
download 'speedyspeech_nosil_baker_ckpt_0.5.tar.gz' \
'https://paddlespeech.bj.bcebos.com/t2s/text_frontend/speedyspeech_nosil_baker_ckpt_0.5.tar.gz' \
'0b7754b21f324789aef469c61f4d5b8f' \
"$DIST_DIR"
download 'jieba.tar.gz' \
'https://paddlespeech.bj.bcebos.com/t2s/text_frontend/jieba.tar.gz' \
'6d30f426bd8c0025110a483f051315ca' \
"$DIST_DIR"
download 'tranditional_to_simplified.tar.gz' \
'https://paddlespeech.bj.bcebos.com/t2s/text_frontend/tranditional_to_simplified.tar.gz' \
'258f5b59d5ebfe96d02007ca1d274a7f' \
"$DIST_DIR"
echo "Done."

@ -0,0 +1,21 @@
# jieba conf
--jieba_dict_path=./front_demo/dict/jieba/jieba.dict.utf8
--jieba_hmm_path=./front_demo/dict/jieba/hmm_model.utf8
--jieba_user_dict_path=./front_demo/dict/jieba/user.dict.utf8
--jieba_idf_path=./front_demo/dict/jieba/idf.utf8
--jieba_stop_word_path=./front_demo/dict/jieba/stop_words.utf8
# dict conf fastspeech2_0.4
--separate_tone=false
--word2phone_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict
--phone2id_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
--tone2id_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict
# dict conf speedyspeech_0.5
#--separate_tone=true
#--word2phone_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/word2phone.dict
#--phone2id_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt
#--tone2id_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt
# dictionary for traditional-to-simplified conversion
--trand2simpd_path=./front_demo/dict/tranditional_to_simplified/trand2simp.txt

@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <map>
#include <string>
#include "front/front_interface.h"
DEFINE_string(sentence, "你好,欢迎使用语音合成服务", "Text to be synthesized");
DEFINE_string(front_conf, "./front_demo/front.conf", "Front conf file");
// DEFINE_string(separate_tone, "true", "If true, get phoneids and tonesid");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
    // Instantiate the text frontend engine
ppspeech::FrontEngineInterface* front_inst = nullptr;
front_inst = new ppspeech::FrontEngineInterface(FLAGS_front_conf);
if ((!front_inst) || (front_inst->init())) {
LOG(ERROR) << "Creater tts engine failed!";
if (front_inst != nullptr) {
delete front_inst;
}
front_inst = nullptr;
return -1;
}
std::wstring ws_sentence = ppspeech::utf8string2wstring(FLAGS_sentence);
    // Convert traditional Chinese to simplified
std::wstring sentence_simp;
front_inst->Trand2Simp(ws_sentence, &sentence_simp);
ws_sentence = sentence_simp;
std::string s_sentence;
std::vector<std::wstring> sentence_part;
std::vector<int> phoneids = {};
std::vector<int> toneids = {};
    // Split the text into sentences by punctuation
    LOG(INFO) << "Start to segment sentences by punctuation";
    front_inst->SplitByPunc(ws_sentence, &sentence_part);
    LOG(INFO) << "Sentences segmented by punctuation successfully";
    // Get the phoneme ids of each sentence after splitting
LOG(INFO)
<< "Start to get the phoneme and tone id sequence of each sentence";
for (int i = 0; i < sentence_part.size(); i++) {
LOG(INFO) << "Raw sentence is: "
<< ppspeech::wstring2utf8string(sentence_part[i]);
front_inst->SentenceNormalize(&sentence_part[i]);
s_sentence = ppspeech::wstring2utf8string(sentence_part[i]);
LOG(INFO) << "After normalization sentence is: " << s_sentence;
if (0 != front_inst->GetSentenceIds(s_sentence, &phoneids, &toneids)) {
LOG(ERROR) << "TTS inst get sentence phoneids and toneids failed";
return -1;
}
}
LOG(INFO) << "The phoneids of the sentence is: "
<< limonp::Join(phoneids.begin(), phoneids.end(), " ");
LOG(INFO) << "The toneids of the sentence is: "
<< limonp::Join(toneids.begin(), toneids.end(), " ");
LOG(INFO) << "Get the phoneme id sequence of each sentence successfully";
return EXIT_SUCCESS;
}

@ -0,0 +1,111 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import configparser
from paddlespeech.t2s.frontend.zh_frontend import Frontend
def get_phone(frontend,
word,
merge_sentences=True,
print_info=False,
robot=False,
get_tone_ids=False):
phonemes = frontend.get_phonemes(word, merge_sentences, print_info, robot)
    # split the merged phoneme sequence into phone and (optional) tone lists
phones, tones = frontend._get_phone_tone(phonemes[0], get_tone_ids)
#print(type(phones), phones)
#print(type(tones), tones)
return phones, tones
def gen_word2phone_dict(frontend,
jieba_words_dict,
word2phone_dict,
get_tone=False):
with open(jieba_words_dict, "r") as f1, open(word2phone_dict, "w+") as f2:
for line in f1.readlines():
word = line.split(" ")[0]
phone, tone = get_phone(frontend, word, get_tone_ids=get_tone)
phone_str = ""
if tone:
assert (len(phone) == len(tone))
for i in range(len(tone)):
phone_tone = phone[i] + tone[i]
phone_str += (" " + phone_tone)
phone_str = phone_str.strip("sp0").strip(" ")
else:
for x in phone:
phone_str += (" " + x)
phone_str = phone_str.strip("sp").strip(" ")
print(phone_str)
f2.write(word + " " + phone_str + "\n")
print("Generate word2phone dict successfully.")
def main():
parser = argparse.ArgumentParser(description="Generate dictionary")
parser.add_argument(
"--config", type=str, default="./config.ini", help="config file.")
parser.add_argument(
"--am_type",
type=str,
default="fastspeech2",
help="fastspeech2 or speedyspeech")
args = parser.parse_args()
# Read config
cf = configparser.ConfigParser()
cf.read(args.config)
jieba_words_dict_file = cf.get("jieba",
"jieba_words_dict") # get words dict
am_type = args.am_type
if (am_type == "fastspeech2"):
phone2id_dict_file = cf.get(am_type, "phone2id_dict")
word2phone_dict_file = cf.get(am_type, "word2phone_dict")
frontend = Frontend(phone_vocab_path=phone2id_dict_file)
print("frontend done!")
gen_word2phone_dict(
frontend,
jieba_words_dict_file,
word2phone_dict_file,
get_tone=False)
elif (am_type == "speedyspeech"):
phone2id_dict_file = cf.get(am_type, "phone2id_dict")
tone2id_dict_file = cf.get(am_type, "tone2id_dict")
word2phone_dict_file = cf.get(am_type, "word2phone_dict")
frontend = Frontend(
phone_vocab_path=phone2id_dict_file,
tone_vocab_path=tone2id_dict_file)
print("frontend done!")
gen_word2phone_dict(
frontend,
jieba_words_dict_file,
word2phone_dict_file,
get_tone=True)
else:
print("Please set correct am type, fastspeech2 or speedyspeech.")
if __name__ == "__main__":
main()
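For reference, a hypothetical `config.ini` for the script above; the section and key names match the `cf.get()` calls, while all paths are example values only:
```python
# Hypothetical config.ini consumed by main() above; keys match the cf.get()
# calls in the script, paths are example values.
CONFIG_INI = """\
[jieba]
jieba_words_dict = ./dict/jieba/jieba.dict.utf8

[fastspeech2]
phone2id_dict   = ./dict/fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
word2phone_dict = ./dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict

[speedyspeech]
phone2id_dict   = ./dict/speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt
tone2id_dict    = ./dict/speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt
word2phone_dict = ./dict/speedyspeech_nosil_baker_ckpt_0.5/word2phone.dict
"""
```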

@ -0,0 +1,35 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PHONESFILE = "./dict/phones.txt"
PHONES_ID_FILE = "./dict/phonesid.dict"
TONESFILE = "./dict/tones.txt"
TONES_ID_FILE = "./dict/tonesid.dict"
def GenIdFile(file, idfile):
id = 2
with open(file, 'r') as f1, open(idfile, "w+") as f2:
f2.write("<pad> 0\n")
f2.write("<unk> 1\n")
for line in f1.readlines():
phone = line.strip()
print(phone + " " + str(id) + "\n")
f2.write(phone + " " + str(id) + "\n")
id += 1
if __name__ == "__main__":
GenIdFile(PHONESFILE, PHONES_ID_FILE)
GenIdFile(TONESFILE, TONES_ID_FILE)

@ -0,0 +1,55 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from pypinyin import lazy_pinyin
from pypinyin import Style
worddict = "./dict/jieba_part.dict.utf8"
newdict = "./dict/word_phones.dict"
def GenPhones(initials, finals, separate=True):
phones = []
for c, v in zip(initials, finals):
if re.match(r'i\d', v):
if c in ['z', 'c', 's']:
v = re.sub('i', 'ii', v)
elif c in ['zh', 'ch', 'sh', 'r']:
v = re.sub('i', 'iii', v)
if c:
if separate is True:
phones.append(c + '0')
elif separate is False:
phones.append(c)
else:
print("Not sure whether phone and tone need to be separated")
if v:
phones.append(v)
return phones
with open(worddict, "r") as f1, open(newdict, "w+") as f2:
for line in f1.readlines():
word = line.split(" ")[0]
initials = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.INITIALS)
finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
phones = GenPhones(initials, finals, True)
temp = " ".join(phones)
f2.write(word + " " + temp + "\n")

@ -0,0 +1,7 @@
#!/bin/bash
set -e
set -x
cd "$(dirname "$(realpath "$0")")"
./build/tts_front_demo "$@"

@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/type_conv.h"
namespace ppspeech {
// wstring to string
std::string wstring2utf8string(const std::wstring& str) {
static std::wstring_convert<std::codecvt_utf8<wchar_t>> strCnv;
return strCnv.to_bytes(str);
}
// string to wstring
std::wstring utf8string2wstring(const std::string& str) {
static std::wstring_convert<std::codecvt_utf8<wchar_t>> strCnv;
return strCnv.from_bytes(str);
}
} // namespace ppspeech

@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BASE_TYPE_CONV_H
#define BASE_TYPE_CONV_H
#include <codecvt>
#include <locale>
#include <string>
namespace ppspeech {
// wstring to string
std::string wstring2utf8string(const std::wstring& str);
// string to wstring
std::wstring utf8string2wstring(const std::string& str);
}  // namespace ppspeech
#endif  // BASE_TYPE_CONV_H

File diff suppressed because it is too large

@ -0,0 +1,198 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_TTS_SERVING_FRONT_FRONT_INTERFACE_H
#define PADDLE_TTS_SERVING_FRONT_FRONT_INTERFACE_H
#include <glog/logging.h>
#include <fstream>
#include <map>
#include <memory>
#include <string>
//#include "utils/dir_utils.h"
#include <cppjieba/Jieba.hpp>
#include "absl/strings/str_split.h"
#include "front/text_normalize.h"
namespace ppspeech {
class FrontEngineInterface : public TextNormalizer {
public:
explicit FrontEngineInterface(std::string conf) : _conf_file(conf) {
TextNormalizer();
_jieba = nullptr;
_initialed = false;
init();
}
int init();
~FrontEngineInterface() {}
// Read the configuration file
int ReadConfFile();
// Convert traditional characters to simplified ones
int Trand2Simp(const std::wstring &sentence, std::wstring *sentence_simp);
// Build a lookup dictionary from file
int GenDict(const std::string &file,
            std::map<std::string, std::string> *map);
// Reduce a (word, POS) segmentation result to the words only
int GetSegResult(std::vector<std::pair<std::string, std::string>> *seg,
                 std::vector<std::string> *seg_words);
// Generate the phone/tone ids of a sentence. If phones and tones are not
// separated (fastspeech2), toneids is empty; otherwise (speedyspeech) it is
// not.
int GetSentenceIds(const std::string &sentence,
                   std::vector<int> *phoneids,
                   std::vector<int> *toneids);
// Get the phone/tone ids of each word in a segmentation result and adjust
// pronunciations where needed (ModifyTone). If phones and tones are not
// separated (fastspeech2), toneids is empty; otherwise (speedyspeech) it is
// not.
int GetWordsIds(
    const std::vector<std::pair<std::string, std::string>> &cut_result,
    std::vector<int> *phoneids,
    std::vector<int> *toneids);
// Segment the sentence with jieba into (word, POS) pairs, then post-process
// the result (MergeforModify)
int Cut(const std::string &sentence,
        std::vector<std::pair<std::string, std::string>> *cut_result);
// Map a word to its phones by dictionary lookup
int GetPhone(const std::string &word, std::string *phone);
// Map phones to phone ids
int Phone2Phoneid(const std::string &phone,
                  std::vector<int> *phoneid,
                  std::vector<int> *toneids);
// Judge from the finals whether every character of the word is third tone;
// returns true if so
bool AllToneThree(const std::vector<std::string> &finals);
// Judge whether the word is a reduplication
bool IsReduplication(const std::string &word);
// Get the list of initials and finals for each character of the word
int GetInitialsFinals(const std::string &word,
                      std::vector<std::string> *word_initials,
                      std::vector<std::string> *word_finals);
// Get the list of finals for each character of the word
int GetFinals(const std::string &word,
              std::vector<std::string> *word_finals);
// Convert the word into a vector with one element per character
int Word2WordVec(const std::string &word,
                 std::vector<std::wstring> *wordvec);
// Re-segment the word with jieba full cut so that every sub-word is in the
// dictionary
int SplitWord(const std::string &word,
              std::vector<std::string> *fullcut_word);
// Post-process segmentation: merge segments around "不"
std::vector<std::pair<std::string, std::string>> MergeBu(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Post-process segmentation: merge segments around "一"
std::vector<std::pair<std::string, std::string>> Mergeyi(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Post-process segmentation: merge adjacent identical characters
std::vector<std::pair<std::string, std::string>> MergeReduplication(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Merge a word with the next one when both are read entirely in third tone
std::vector<std::pair<std::string, std::string>> MergeThreeTones(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Merge a word with the next one when the last syllable of the first and
// the first syllable of the second are both third tone
std::vector<std::pair<std::string, std::string>> MergeThreeTones2(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Post-process segmentation: merge segments around "儿" (erhua)
std::vector<std::pair<std::string, std::string>> MergeEr(
    std::vector<std::pair<std::string, std::string>> *seg_result);
// Post-process and adjust the whole segmentation result
int MergeforModify(
    std::vector<std::pair<std::string, std::string>> *seg_result,
    std::vector<std::pair<std::string, std::string>> *merge_seg_result);
// Tone sandhi for words containing "不"
int BuSandi(const std::string &word, std::vector<std::string> *finals);
// Tone sandhi for words containing "一"
int YiSandhi(const std::string &word, std::vector<std::string> *finals);
// Neutral-tone sandhi for special words (classifiers, particles, etc.)
int NeuralSandhi(const std::string &word,
                 const std::string &pos,
                 std::vector<std::string> *finals);
// Third-tone sandhi
int ThreeSandhi(const std::string &word, std::vector<std::string> *finals);
// Apply all tone modifications to a word
int ModifyTone(const std::string &word,
               const std::string &pos,
               std::vector<std::string> *finals);
// Handle erhua (rhotacized finals)
std::vector<std::vector<std::string>> MergeErhua(
    const std::vector<std::string> &initials,
    const std::vector<std::string> &finals,
    const std::string &word,
    const std::string &pos);
private:
bool _initialed;
cppjieba::Jieba *_jieba;
std::vector<std::string> _punc;
std::vector<std::string> _punc_omit;
std::string _conf_file;
std::map<std::string, std::string> conf_map;
std::map<std::string, std::string> word_phone_map;
std::map<std::string, std::string> phone_id_map;
std::map<std::string, std::string> tone_id_map;
std::map<std::string, std::string> trand_simp_map;
std::string _jieba_dict_path;
std::string _jieba_hmm_path;
std::string _jieba_user_dict_path;
std::string _jieba_idf_path;
std::string _jieba_stop_word_path;
std::string _separate_tone;
std::string _word2phone_path;
std::string _phone2id_path;
std::string _tone2id_path;
std::string _trand2simp_path;
std::vector<std::string> must_erhua;
std::vector<std::string> not_erhua;
std::vector<std::string> must_not_neural_tone_words;
std::vector<std::string> must_neural_tone_words;
};
} // namespace ppspeech
#endif
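For orientation, here is a sketch of how this interface might be driven end to end. The conf path is a placeholder, and the exact processing inside `GetSentenceIds` is described only by the header comments above, so treat this as illustrative rather than the demo's actual flow:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "front/front_interface.h"

int main() {
    // Placeholder conf path; the real file ships with the demo assets.
    ppspeech::FrontEngineInterface front("./front.conf");

    std::vector<int> phoneids;
    std::vector<int> toneids;
    // Per the comments above, toneids stays empty for fastspeech2-style
    // models where phones and tones are not separated.
    if (front.GetSentenceIds("今天天气不错。", &phoneids, &toneids) != 0) {
        std::cerr << "front engine failed" << std::endl;
        return 1;
    }
    std::cout << "phone ids: " << phoneids.size() << std::endl;
    return 0;
}
```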

@ -0,0 +1,542 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "front/text_normalize.h"
namespace ppspeech {
// Initialize digits_map and units_map
int TextNormalizer::InitMap() {
digits_map["0"] = "零";
digits_map["1"] = "一";
digits_map["2"] = "二";
digits_map["3"] = "三";
digits_map["4"] = "四";
digits_map["5"] = "五";
digits_map["6"] = "六";
digits_map["7"] = "七";
digits_map["8"] = "八";
digits_map["9"] = "九";
units_map[1] = "十";
units_map[2] = "百";
units_map[3] = "千";
units_map[4] = "万";
units_map[8] = "亿";
return 0;
}
// Replace len characters of the sentence at pos with repstr
int TextNormalizer::Replace(std::wstring *sentence,
const int &pos,
const int &len,
const std::wstring &repstr) {
// erase the original characters
sentence->erase(pos, len);
// insert the replacement
sentence->insert(pos, repstr);
return 0;
}
// Split the sentence into clauses at punctuation marks
int TextNormalizer::SplitByPunc(const std::wstring &sentence,
std::vector<std::wstring> *sentence_part) {
std::wstring temp = sentence;
std::wregex reg(L"[:,;。?!,;?!]");
std::wsmatch match;
while (std::regex_search(temp, match, reg)) {
sentence_part->push_back(
temp.substr(0, match.position(0) + match.length(0)));
Replace(&temp, 0, match.position(0) + match.length(0), L"");
}
// keep any trailing text without terminating punctuation
if (temp != L"") {
sentence_part->push_back(temp);
}
return 0;
}
// Convert a number to text, e.g. 10200 -> 一万零二百
std::string TextNormalizer::CreateTextValue(const std::string &num_str,
bool use_zero) {
std::string num_lstrip =
std::string(absl::StripPrefix(num_str, "0")).data();
int len = num_lstrip.length();
if (len == 0) {
return "";
} else if (len == 1) {
if (use_zero && (len < num_str.length())) {
return digits_map["0"] + digits_map[num_lstrip];
} else {
return digits_map[num_lstrip];
}
} else {
int largest_unit = 0; // largest place-value unit
std::string first_part;
std::string second_part;
if (len > 1 && len <= 2) {
largest_unit = 1;
} else if (len > 2 && len <= 3) {
largest_unit = 2;
} else if (len > 3 && len <= 4) {
largest_unit = 3;
} else if (len > 4 && len <= 8) {
largest_unit = 4;
} else if (len > 8) {
largest_unit = 8;
}
first_part = num_str.substr(0, num_str.length() - largest_unit);
second_part = num_str.substr(num_str.length() - largest_unit);
return CreateTextValue(first_part, use_zero) + units_map[largest_unit] +
CreateTextValue(second_part, use_zero);
}
}
// Read digits one by one; suitable for years and phone numbers
std::string TextNormalizer::SingleDigit2Text(const std::string &num_str,
bool alt_one) {
std::string text = "";
if (alt_one) {
digits_map["1"] = "";
} else {
digits_map["1"] = "";
}
for (size_t i = 0; i < num_str.size(); i++) {
std::string num_int(1, num_str[i]);
if (digits_map.find(num_int) == digits_map.end()) {
LOG(ERROR) << "digits_map doesn't have key: " << num_int;
}
text += digits_map[num_int];
}
return text;
}
std::string TextNormalizer::SingleDigit2Text(const std::wstring &num,
bool alt_one) {
std::string num_str = wstring2utf8string(num);
return SingleDigit2Text(num_str, alt_one);
}
// Read the number as a whole; suitable for months, dates and the integer
// part of a value
std::string TextNormalizer::MultiDigit2Text(const std::string &num_str,
                                            bool alt_one,
                                            bool use_zero) {
if (alt_one) {
digits_map["1"] = "幺";
} else {
digits_map["1"] = "一";
}
std::wstring result =
utf8string2wstring(CreateTextValue(num_str, use_zero));
std::wstring result_0(1, result[0]);
std::wstring result_1(1, result[1]);
// drop the leading "一", e.g. 一十八 --> 十八
if ((result_0 == utf8string2wstring(digits_map["1"])) &&
(result_1 == utf8string2wstring(units_map[1]))) {
return wstring2utf8string(result.substr(1, result.length()));
} else {
return wstring2utf8string(result);
}
}
std::string TextNormalizer::MultiDigit2Text(const std::wstring &num,
bool alt_one,
bool use_zero) {
std::string num_str = wstring2utf8string(num);
return MultiDigit2Text(num_str, alt_one, use_zero);
}
// Convert a number, integer and decimal parts included, to text
std::string TextNormalizer::Digits2Text(const std::string &num_str) {
std::string text;
std::vector<std::string> integer_decimal;
integer_decimal = absl::StrSplit(num_str, ".");
if (integer_decimal.size() == 1) { // integer
text = MultiDigit2Text(integer_decimal[0]);
} else if (integer_decimal.size() == 2) { // decimal
if (integer_decimal[0] == "") { // decimal with no integer part, e.g. .22
text = "点" +
       SingleDigit2Text(
           std::string(absl::StripSuffix(integer_decimal[1], "0"))
               .data());
} else { // regular decimal, e.g. 12.34
text = MultiDigit2Text(integer_decimal[0]) + "点" +
       SingleDigit2Text(
           std::string(absl::StripSuffix(integer_decimal[1], "0"))
               .data());
}
} else {
return "The value does not conform to the numeric format";
}
return text;
}
std::string TextNormalizer::Digits2Text(const std::wstring &num) {
std::string num_str = wstring2utf8string(num);
return Digits2Text(num_str);
}
// Date, e.g. 2021年8月18日 --> 二零二一年八月十八日
int TextNormalizer::ReData(std::wstring *sentence) {
std::wregex reg(
L"(\\d{4}|\\d{2})年((0?[1-9]|1[0-2])月)?(((0?[1-9])|((1|2)[0-9])|30|31)"
L"([日号]))?");
std::wsmatch match;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = "";
rep += SingleDigit2Text(match[1]) + "年";
if (match[3] != L"") {
rep += MultiDigit2Text(match[3], false, false) + "月";
}
if (match[5] != L"") {
rep += MultiDigit2Text(match[5], false, false) +
wstring2utf8string(match[9]);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// XX-XX-XX or XX/XX/XX, e.g. 2021/08/18 --> 二零二一年八月十八日
int TextNormalizer::ReData2(std::wstring *sentence) {
std::wregex reg(
L"(\\d{4})([- /.])(0[1-9]|1[012])\\2(0[1-9]|[12][0-9]|3[01])");
std::wsmatch match;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = "";
rep += (SingleDigit2Text(match[1]) + "年");
rep += (MultiDigit2Text(match[3], false, false) + "月");
rep += (MultiDigit2Text(match[4], false, false) + "日");
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Time XX:XX(:XX), e.g. 09:09:02 --> 九点零九分零二秒
int TextNormalizer::ReTime(std::wstring *sentence) {
std::wregex reg(L"([0-1]?[0-9]|2[0-3]):([0-5][0-9])(:([0-5][0-9]))?");
std::wsmatch match;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = "";
rep += (MultiDigit2Text(match[1], false, false) + "");
if (absl::StartsWith(wstring2utf8string(match[2]), "0")) {
rep += "";
}
rep += (MultiDigit2Text(match[2]) + "");
if (absl::StartsWith(wstring2utf8string(match[4]), "0")) {
rep += "";
}
rep += (MultiDigit2Text(match[4]) + "");
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Temperature, e.g. -24.3℃ --> 零下二十四点三度
int TextNormalizer::ReTemperature(std::wstring *sentence) {
std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)(°C|℃|度|摄氏度)");
std::wsmatch match;
std::string rep;
std::string sign;
std::vector<std::string> integer_decimal;
std::string unit;
while (std::regex_search(*sentence, match, reg)) {
match[1] == L"-" ? sign = "" : sign = "";
match[4] == L"摄氏度" ? unit = "摄氏度" : unit = "";
rep = sign + Digits2Text(match[2]) + unit;
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Fraction, e.g. 1/3 --> 三分之一
int TextNormalizer::ReFrac(std::wstring *sentence) {
std::wregex reg(L"(-?)(\\d+)/(\\d+)");
std::wsmatch match;
std::string sign;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
match[1] == L"-" ? sign = "" : sign = "";
rep = sign + MultiDigit2Text(match[3]) + "分之" +
MultiDigit2Text(match[2]);
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Percentage, e.g. 45.5% --> 百分之四十五点五
int TextNormalizer::RePercentage(std::wstring *sentence) {
std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)%");
std::wsmatch match;
std::string sign;
std::string rep;
std::vector<std::string> integer_decimal;
while (std::regex_search(*sentence, match, reg)) {
match[1] == L"-" ? sign = "" : sign = "";
rep = sign + "百分之" + Digits2Text(match[2]);
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Mobile phone number, e.g. +86 18883862235 --> 八六幺八八八三八六二二三五
int TextNormalizer::ReMobilePhone(std::wstring *sentence) {
std::wregex reg(
L"(\\d)?((\\+?86 ?)?1([38]\\d|5[0-35-9]|7[678]|9[89])\\d{8})(\\d)?");
std::wsmatch match;
std::string rep;
std::vector<std::string> country_phonenum;
while (std::regex_search(*sentence, match, reg)) {
country_phonenum = absl::StrSplit(wstring2utf8string(match[0]), "+");
rep = "";
for (size_t i = 0; i < country_phonenum.size(); i++) {
LOG(INFO) << country_phonenum[i];
rep += SingleDigit2Text(country_phonenum[i], true);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Landline number, e.g. 010-51093154 --> 零幺零五幺零九三幺五四
int TextNormalizer::RePhone(std::wstring *sentence) {
std::wregex reg(
L"(\\d)?((0(10|2[1-3]|[3-9]\\d{2})-?)?[1-9]\\d{6,7})(\\d)?");
std::wsmatch match;
std::vector<std::string> zone_phonenum;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = "";
zone_phonenum = absl::StrSplit(wstring2utf8string(match[0]), "-");
for (size_t i = 0; i < zone_phonenum.size(); i++) {
rep += SingleDigit2Text(zone_phonenum[i], true);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Range, e.g. 60~90 --> 六十到九十
int TextNormalizer::ReRange(std::wstring *sentence) {
std::wregex reg(
L"((-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+)))[-~]((-?)((\\d+)(\\.\\d+)?)|(\\.("
L"\\d+)))");
std::wsmatch match;
std::string rep;
std::string sign1;
std::string sign2;
while (std::regex_search(*sentence, match, reg)) {
rep = "";
match[2] == L"-" ? sign1 = "" : sign1 = "";
if (match[6] != L"") {
rep += sign1 + Digits2Text(match[6]) + "";
} else {
rep += sign1 + Digits2Text(match[3]) + "";
}
match[9] == L"-" ? sign2 = "" : sign2 = "";
if (match[13] != L"") {
rep += sign2 + Digits2Text(match[13]);
} else {
rep += sign2 + Digits2Text(match[10]);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Negative integer, e.g. -10 --> 负十
int TextNormalizer::ReInterger(std::wstring *sentence) {
std::wregex reg(L"(-)(\\d+)");
std::wsmatch match;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = "" + MultiDigit2Text(match[2]);
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Plain decimal number
int TextNormalizer::ReDecimalNum(std::wstring *sentence) {
std::wregex reg(L"(-?)((\\d+)(\\.\\d+))|(\\.(\\d+))");
std::wsmatch match;
std::string sign;
std::string rep;
// std::vector<std::string> integer_decimal;
while (std::regex_search(*sentence, match, reg)) {
match[1] == L"-" ? sign = "" : sign = "";
if (match[5] != L"") {
rep = sign + Digits2Text(match[5]);
} else {
rep = sign + Digits2Text(match[2]);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Positive integer + classifier
int TextNormalizer::RePositiveQuantifiers(std::wstring *sentence) {
std::wstring common_quantifiers =
L"(朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|"
L"担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|"
L"溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|"
L"本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
L"毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|"
L"合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|"
L"卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|纪|岁|世|更|"
L"夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|"
L"元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|"
L"百万|万|千|百|)块|角|毛|分)";
std::wregex reg(L"(\\d+)([多余几])?" + common_quantifiers);
std::wsmatch match;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
rep = MultiDigit2Text(match[1]);
Replace(sentence,
match.position(1),
match.length(1),
utf8string2wstring(rep));
}
return 0;
}
// Serial numbers read digit by digit, e.g. 89757 --> 八九七五七
int TextNormalizer::ReDefalutNum(std::wstring *sentence) {
std::wregex reg(L"\\d{3}\\d*");
std::wsmatch match;
while (std::regex_search(*sentence, match, reg)) {
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(SingleDigit2Text(match[0])));
}
return 0;
}
int TextNormalizer::ReNumber(std::wstring *sentence) {
std::wregex reg(L"(-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+))");
std::wsmatch match;
std::string sign;
std::string rep;
while (std::regex_search(*sentence, match, reg)) {
match[1] == L"-" ? sign = "" : sign = "";
if (match[5] != L"") {
rep = sign + Digits2Text(match[5]);
} else {
rep = sign + Digits2Text(match[2]);
}
Replace(sentence,
match.position(0),
match.length(0),
utf8string2wstring(rep));
}
return 0;
}
// Apply all normalization rules, in order
int TextNormalizer::SentenceNormalize(std::wstring *sentence) {
ReData(sentence);
ReData2(sentence);
ReTime(sentence);
ReTemperature(sentence);
ReFrac(sentence);
RePercentage(sentence);
ReMobilePhone(sentence);
RePhone(sentence);
ReRange(sentence);
ReInterger(sentence);
ReDecimalNum(sentence);
RePositiveQuantifiers(sentence);
ReDefalutNum(sentence);
ReNumber(sentence);
return 0;
}
} // namespace ppspeech
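The ordering in `SentenceNormalize` matters: the date, time, temperature and phone-number rules must run before the generic number rules, otherwise `ReNumber` would consume their digits first. A small driver for the pipeline might look like this sketch; the expected rewrites are the ones stated in the rule comments above:

```cpp
#include <iostream>
#include <string>

#include "front/text_normalize.h"

int main() {
    ppspeech::TextNormalizer normalizer;
    // Per the rule comments, the date becomes "二零二一年八月十八日" and the
    // temperature becomes "零下二十四点三度".
    std::wstring sentence =
        ppspeech::utf8string2wstring("2021年8月18日气温-24.3℃。");
    normalizer.SentenceNormalize(&sentence);
    std::cout << ppspeech::wstring2utf8string(sentence) << std::endl;
    return 0;
}
```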

@ -0,0 +1,77 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_TTS_SERVING_FRONT_TEXT_NORMALIZE_H
#define PADDLE_TTS_SERVING_FRONT_TEXT_NORMALIZE_H
#include <glog/logging.h>
#include <codecvt>
#include <map>
#include <regex>
#include <string>
#include <vector>
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "base/type_conv.h"
namespace ppspeech {
class TextNormalizer {
public:
TextNormalizer() { InitMap(); }
~TextNormalizer() {}
int InitMap();
int Replace(std::wstring *sentence,
const int &pos,
const int &len,
const std::wstring &repstr);
int SplitByPunc(const std::wstring &sentence,
std::vector<std::wstring> *sentence_part);
std::string CreateTextValue(const std::string &num, bool use_zero = true);
std::string SingleDigit2Text(const std::string &num_str,
bool alt_one = false);
std::string SingleDigit2Text(const std::wstring &num, bool alt_one = false);
std::string MultiDigit2Text(const std::string &num_str,
bool alt_one = false,
bool use_zero = true);
std::string MultiDigit2Text(const std::wstring &num,
bool alt_one = false,
bool use_zero = true);
std::string Digits2Text(const std::string &num_str);
std::string Digits2Text(const std::wstring &num);
int ReData(std::wstring *sentence);
int ReData2(std::wstring *sentence);
int ReTime(std::wstring *sentence);
int ReTemperature(std::wstring *sentence);
int ReFrac(std::wstring *sentence);
int RePercentage(std::wstring *sentence);
int ReMobilePhone(std::wstring *sentence);
int RePhone(std::wstring *sentence);
int ReRange(std::wstring *sentence);
int ReInterger(std::wstring *sentence);
int ReDecimalNum(std::wstring *sentence);
int RePositiveQuantifiers(std::wstring *sentence);
int ReDefalutNum(std::wstring *sentence);
int ReNumber(std::wstring *sentence);
int SentenceNormalize(std::wstring *sentence);
private:
std::map<std::string, std::string> digits_map;
std::map<int, std::string> units_map;
};
} // namespace ppspeech
#endif
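A typical call pattern suggested by this interface is to split the input at punctuation first and then normalize clause by clause; the sketch below assumes that usage, with a hypothetical input string:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "front/text_normalize.h"

int main() {
    ppspeech::TextNormalizer normalizer;
    std::wstring text = ppspeech::utf8string2wstring(
        "你好,今天是2023年5月1日。明天见!");
    std::vector<std::wstring> clauses;
    normalizer.SplitByPunc(text, &clauses);
    // SplitByPunc yields one entry per clause, punctuation included.
    for (auto &clause : clauses) {
        normalizer.SentenceNormalize(&clause);
        std::cout << ppspeech::wstring2utf8string(clause) << std::endl;
    }
    return 0;
}
```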

@ -0,0 +1,64 @@
cmake_minimum_required(VERSION 3.10)
project(tts_third_party_libs)
include(ExternalProject)
# gflags
ExternalProject_Add(gflags
GIT_REPOSITORY https://github.com/gflags/gflags.git
GIT_TAG v2.2.2
PREFIX ${CMAKE_CURRENT_BINARY_DIR}
INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_STATIC_LIBS=OFF
-DBUILD_SHARED_LIBS=ON
)
# glog
ExternalProject_Add(
glog
GIT_REPOSITORY https://github.com/google/glog.git
GIT_TAG v0.6.0
PREFIX ${CMAKE_CURRENT_BINARY_DIR}
INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
DEPENDS gflags
)
# abseil
ExternalProject_Add(
abseil
GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
GIT_TAG 20230125.1
PREFIX ${CMAKE_CURRENT_BINARY_DIR}
INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DABSL_PROPAGATE_CXX_STD=ON
)
# cppjieba (header-only)
ExternalProject_Add(
cppjieba
GIT_REPOSITORY https://github.com/yanyiwu/cppjieba.git
GIT_TAG v5.0.3
PREFIX ${CMAKE_CURRENT_BINARY_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
# limonp (header-only)
ExternalProject_Add(
limonp
GIT_REPOSITORY https://github.com/yanyiwu/limonp.git
GIT_TAG v0.6.6
PREFIX ${CMAKE_CURRENT_BINARY_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)

@ -14,8 +14,8 @@
from audio_search import app
from fastapi.testclient import TestClient
from utils.utility import download
from utils.utility import unpack
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
client = TestClient(app)

@ -14,8 +14,8 @@
from fastapi.testclient import TestClient
from vpr_search import app
from utils.utility import download
from utils.utility import unpack
from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
client = TestClient(app)

@ -17,7 +17,7 @@ The input of this demo should be a WAV file(`.wav`), and the sample rate must be
Here are sample files for this demo that can be downloaded:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/ch_zh_mix.wav
```
### 3. Usage
@ -27,6 +27,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
paddlespeech asr --input ./zh.wav -v
# English
paddlespeech asr --model transformer_librispeech --lang en --input ./en.wav -v
# Code-Switch
paddlespeech asr --model conformer_talcs --lang zh_en --codeswitch True --input ./ch_zh_mix.wav -v
# Chinese ASR + Punctuation Restoration
paddlespeech asr --input ./zh.wav -v | paddlespeech text --task punc -v
```
@ -40,6 +42,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `conformer_wenetspeech`.
- `lang`: Model language. Default: `zh`.
- `codeswitch`: Whether to use a code-switch (Chinese-English mixed) model. Default: `False`.
- `sample_rate`: Sample rate of the model. Default: `16000`.
- `config`: Config of asr task. Use pretrained model when it is None. Default: `None`.
- `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
@ -83,14 +86,15 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Here is a list of pretrained models released by PaddleSpeech that can be used by command and python API:
| Model | Language | Sample Rate
| :--- | :---: | :---: |
| conformer_wenetspeech | zh | 16k
| conformer_online_multicn | zh | 16k
| conformer_aishell | zh | 16k
| conformer_online_aishell | zh | 16k
| transformer_librispeech | en | 16k
| deepspeech2online_wenetspeech | zh | 16k
| deepspeech2offline_aishell| zh| 16k
| deepspeech2online_aishell | zh | 16k
| deepspeech2offline_librispeech | en | 16k
| Model | Code Switch | Language | Sample Rate
| :--- | :---: | :---: | :---: |
| conformer_wenetspeech | False | zh | 16k
| conformer_online_multicn | False | zh | 16k
| conformer_aishell | False | zh | 16k
| conformer_online_aishell | False | zh | 16k
| transformer_librispeech | False | en | 16k
| deepspeech2online_wenetspeech | False | zh | 16k
| deepspeech2offline_aishell | False | zh| 16k
| deepspeech2online_aishell | False | zh | 16k
| deepspeech2offline_librispeech | False | en | 16k
| conformer_talcs | True | zh_en | 16k

@ -1,4 +1,5 @@
(简体中文|[English](./README.md))
(简体中文|[English](./README.md))
# Speech Recognition
## Introduction
@ -16,7 +17,7 @@
Sample audio files for this demo can be downloaded:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/ch_zh_mix.wav
```
### 3. Usage
- Command line (recommended)
@ -25,6 +26,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
paddlespeech asr --input ./zh.wav -v
# English
paddlespeech asr --model transformer_librispeech --lang en --input ./en.wav -v
# Chinese-English code-switch
paddlespeech asr --model conformer_talcs --lang zh_en --codeswitch True --input ./ch_zh_mix.wav -v
# Chinese ASR + punctuation restoration
paddlespeech asr --input ./zh.wav -v | paddlespeech text --task punc -v
```
@ -38,6 +41,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- `input` (required): Audio file to recognize.
- `model`: Model type of the asr task. Default: `conformer_wenetspeech`.
- `lang`: Model language. Default: `zh`.
- `codeswitch`: Whether to use a code-switch model. Default: `False`.
- `sample_rate`: Sample rate of the audio. Default: `16000`.
- `config`: Config file of the asr task; the pretrained model's default config is used when unset. Default: `None`.
- `ckpt_path`: Model checkpoint; the pretrained model is downloaded and used when unset. Default: `None`.
@ -80,14 +84,15 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
### 4. Pretrained Models
The following pretrained models released by PaddleSpeech can be used by the command line and python API:
| Model | Language | Sample Rate
| :--- | :---: | :---: |
| conformer_wenetspeech | zh | 16k
| conformer_online_multicn | zh | 16k
| conformer_aishell | zh | 16k
| conformer_online_aishell | zh | 16k
| transformer_librispeech | en | 16k
| deepspeech2online_wenetspeech | zh | 16k
| deepspeech2offline_aishell| zh| 16k
| deepspeech2online_aishell | zh | 16k
| deepspeech2offline_librispeech | en | 16k
| Model | Code Switch | Language | Sample Rate
| :--- | :---: | :---: | :---: |
| conformer_wenetspeech | False | zh | 16k
| conformer_online_multicn | False | zh | 16k
| conformer_aishell | False | zh | 16k
| conformer_online_aishell | False | zh | 16k
| transformer_librispeech | False | en | 16k
| deepspeech2online_wenetspeech | False | zh | 16k
| deepspeech2offline_aishell | False | zh| 16k
| deepspeech2online_aishell | False | zh | 16k
| deepspeech2offline_librispeech | False | en | 16k
| conformer_talcs | True | zh_en | 16k

@ -2,6 +2,7 @@
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/ch_zh_mix.wav
# asr
paddlespeech asr --input ./zh.wav
@ -18,6 +19,11 @@ paddlespeech asr --help
# english asr
paddlespeech asr --lang en --model transformer_librispeech --input ./en.wav
# code-switch asr
paddlespeech asr --lang zh_en --codeswitch True --model conformer_talcs --input ./ch_zh_mix.wav
# model stats
paddlespeech stats --task asr

@ -23,7 +23,7 @@ Paddle Speech Demo 是一个以 PaddleSpeech 的语音交互功能为主体开
+ ERNIE-SAT: a visualized demo of the ERNIE-SAT cross-modal (language and speech) large model, supporting personalized synthesis, cross-lingual synthesis (English text is synthesized for Chinese audio), and speech editing (modifying words in the middle of an utterance). For more implementation details of ERNIE-SAT, see:
+ [【ERNIE-SAT with AISHELL-3 dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/ernie_sat)
+ [【ERNIE-SAT with with AISHELL3 and VCTK datasets】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3_vctk/ernie_sat)
+ [【ERNIE-SAT with AISHELL3 and VCTK datasets】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3_vctk/ernie_sat)
+ [【ERNIE-SAT with VCTK dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/ernie_sat)
Demo output:

@ -260,7 +260,7 @@ async def websocket_endpoint_online(websocket: WebSocket):
# and we break the loop
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here
# do something at beginning here
# create the instance to process the audio
# connection_handler = chatbot.asr.connection_handler
connection_handler = PaddleASRConnectionHanddler(engine)

@ -1,8 +1,6 @@
aiofiles
faiss-cpu
praatio==5.0.0
praatio>=5.0.0
pydantic
python-multipart
scikit_learn
starlette
uvicorn

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -0,0 +1,162 @@
#!/usr/bin/python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# calc avg RTF(NOT Accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
import argparse
import asyncio
import codecs
import os
import re
from pydub import AudioSegment
from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
def convert_to_wav(input_file):
# Load audio file
audio = AudioSegment.from_file(input_file)
# Set parameters for audio file
audio = audio.set_channels(1)
audio = audio.set_frame_rate(16000)
# Create output filename
output_file = os.path.splitext(input_file)[0] + ".wav"
# Export audio file as WAV
audio.export(output_file, format="wav")
logger.info(f"{input_file} converted to {output_file}")
def format_time(sec):
# Convert seconds to SRT format (HH:MM:SS,ms)
hours = int(sec/3600)
minutes = int((sec%3600)/60)
seconds = int(sec%60)
milliseconds = int((sec%1)*1000)
return f'{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}'
def results2srt(results, srt_file):
"""convert results from paddlespeech to srt format for subtitle
Args:
results (dict): results from paddlespeech
"""
# times contains start and end time of each word
times = results['times']
# result contains the whole sentence including punctuation
result = results['result']
# split result into sentences by ',' and '。'
sentences = re.split(',|。', result)[:-1]
# print("sentences: ", sentences)
# generate relative time for each sentence in sentences
relative_times = []
word_i = 0
for sentence in sentences:
relative_times.append([])
for word in sentence:
if relative_times[-1] == []:
relative_times[-1].append(times[word_i]['bg'])
if len(relative_times[-1]) == 1:
relative_times[-1].append(times[word_i]['ed'])
else:
relative_times[-1][1] = times[word_i]['ed']
word_i += 1
# print("relative_times: ", relative_times)
# generate srt file according to relative_times and sentences
with open(srt_file, 'w') as f:
for i in range(len(sentences)):
# Write index number
f.write(str(i+1)+'\n')
# Write start and end times
start = format_time(relative_times[i][0])
end = format_time(relative_times[i][1])
f.write(start + ' --> ' + end + '\n')
# Write text
f.write(sentences[i]+'\n\n')
logger.info(f"results saved to {srt_file}")
def main(args):
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(
args.server_ip,
args.port,
endpoint=args.endpoint,
punc_server_ip=args.punc_server_ip,
punc_server_port=args.punc_server_port)
loop = asyncio.get_event_loop()
# check if the wav file is mp3 format
# if so, convert it to wav format using convert_to_wav function
if args.wavfile and os.path.exists(args.wavfile):
if args.wavfile.endswith(".mp3"):
convert_to_wav(args.wavfile)
args.wavfile = args.wavfile.replace(".mp3", ".wav")
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logger.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
# result = result["result"]
# logger.info(f"asr websocket client finished : {result}")
results2srt(result, args.wavfile.replace(".wav", ".srt"))
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logger.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["result"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
logger.info("Start to do streaming asr client")
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
parser.add_argument('--port', type=int, default=8090, help='server port')
parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
parser.add_argument(
'--punc.port',
type=int,
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument(
"--endpoint",
type=str,
default="/paddlespeech/asr/streaming",
help="ASR websocket endpoint")
parser.add_argument(
"--wavfile",
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()
main(args)

Some files were not shown because too many files have changed in this diff.
