diff --git a/README.md b/README.md
index a90498293..5093dbd67 100644
--- a/README.md
+++ b/README.md
@@ -280,10 +280,14 @@ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
For more information about server command lines, please see: [speech server demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/speech_server)
+
+
## Model List
PaddleSpeech supports a series of most popular models. They are summarized in [released models](./docs/source/released_model.md) and attached with available pretrained models.
+
+
**Speech-to-Text** contains *Acoustic Model*, *Language Model*, and *Speech Translation*, with the following details:
@@ -357,6 +361,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+
+
**Text-to-Speech** in PaddleSpeech mainly contains three modules: *Text Frontend*, *Acoustic Model* and *Vocoder*. Acoustic Model and Vocoder models are listed as follow:
@@ -457,10 +463,10 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
- GE2E + Tactron2
+ GE2E + Tacotron2
AISHELL-3
- ge2e-tactron2-aishell3
+ ge2e-tacotron2-aishell3
@@ -473,6 +479,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+
+
**Audio Classification**
@@ -496,6 +504,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+
+
**Speaker Verification**
@@ -519,6 +529,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+
+
**Punctuation Restoration**
@@ -559,10 +571,18 @@ Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](ht
- [Advanced Usage](./docs/source/tts/advanced_usage.md)
- [Chinese Rule Based Text Frontend](./docs/source/tts/zh_text_frontend.md)
- [Test Audio Samples](https://paddlespeech.readthedocs.io/en/latest/tts/demo.html)
+ - Speaker Verification
+ - [Audio Searching](./demos/audio_searching/README.md)
+ - [Speaker Verification](./demos/speaker_verification/README.md)
- [Audio Classification](./demos/audio_tagging/README.md)
- - [Speaker Verification](./demos/speaker_verification/README.md)
- [Speech Translation](./demos/speech_translation/README.md)
+ - [Speech Server](./demos/speech_server/README.md)
- [Released Models](./docs/source/released_model.md)
+ - [Speech-to-Text](#SpeechToText)
+ - [Text-to-Speech](#TextToSpeech)
+ - [Audio Classification](#AudioClassification)
+ - [Speaker Verification](#SpeakerVerification)
+ - [Punctuation Restoration](#PunctuationRestoration)
- [Community](#Community)
- [Welcome to contribute](#contribution)
- [License](#License)
diff --git a/README_cn.md b/README_cn.md
index ab4ce6e6b..5dab7fa0c 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -273,6 +273,8 @@ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
## 模型列表
PaddleSpeech 支持很多主流的模型,并提供了预训练模型,详情请见[模型列表](./docs/source/released_model.md)。
+
+
PaddleSpeech 的 **语音转文本** 包含语音识别声学模型、语音识别语言模型和语音翻译, 详情如下:
@@ -347,6 +349,7 @@ PaddleSpeech 的 **语音转文本** 包含语音识别声学模型、语音识
+
PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声学模型和声码器。声学模型和声码器模型如下:
+
+
**声纹识别**
@@ -511,6 +516,8 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+
+
**标点恢复**
@@ -556,13 +563,18 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
- [进阶用法](./docs/source/tts/advanced_usage.md)
- [中文文本前端](./docs/source/tts/zh_text_frontend.md)
- [测试语音样本](https://paddlespeech.readthedocs.io/en/latest/tts/demo.html)
+ - 声纹识别
+ - [声纹识别](./demos/speaker_verification/README_cn.md)
+ - [音频检索](./demos/audio_searching/README_cn.md)
- [声音分类](./demos/audio_tagging/README_cn.md)
- - [声纹识别](./demos/speaker_verification/README_cn.md)
- [语音翻译](./demos/speech_translation/README_cn.md)
+ - [服务化部署](./demos/speech_server/README_cn.md)
- [模型列表](#模型列表)
- [语音识别](#语音识别模型)
- [语音合成](#语音合成模型)
- [声音分类](#声音分类模型)
+ - [声纹识别](#声纹识别模型)
+ - [标点恢复](#标点恢复模型)
- [技术交流群](#技术交流群)
- [欢迎贡献](#欢迎贡献)
- [License](#License)
diff --git a/dataset/rir_noise/rir_noise.py b/dataset/rir_noise/rir_noise.py
index e7b122890..009175e5b 100644
--- a/dataset/rir_noise/rir_noise.py
+++ b/dataset/rir_noise/rir_noise.py
@@ -34,14 +34,14 @@ from utils.utility import unzip
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
-URL_ROOT = 'http://www.openslr.org/resources/28'
+URL_ROOT = '--no-check-certificate http://www.openslr.org/resources/28'
DATA_URL = URL_ROOT + '/rirs_noises.zip'
MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
- default=DATA_HOME + "/Aishell",
+ default=DATA_HOME + "/rirs_noise",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
@@ -81,6 +81,10 @@ def create_manifest(data_dir, manifest_path_prefix):
},
ensure_ascii=False))
manifest_path = manifest_path_prefix + '.' + dtype
+
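+    # make sure the manifest output directory exists before writing the manifest file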
+ if not os.path.exists(os.path.dirname(manifest_path)):
+ os.makedirs(os.path.dirname(manifest_path))
+
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py
index 905862008..95827f708 100644
--- a/dataset/voxceleb/voxceleb1.py
+++ b/dataset/voxceleb/voxceleb1.py
@@ -149,7 +149,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory
if not os.path.exists(os.path.join(target_dir, "wav")):
# download all dataset part
- print("start to download the vox1 dev zip package")
+ print(f"start to download the vox1 zip package to {target_dir}")
for zip_part in data_list.keys():
download_url = " --no-check-certificate " + base_url + "/" + zip_part
download(
diff --git a/dataset/voxceleb/voxceleb2.py b/dataset/voxceleb/voxceleb2.py
index 22a2e2ffe..fe9e8b9c8 100644
--- a/dataset/voxceleb/voxceleb2.py
+++ b/dataset/voxceleb/voxceleb2.py
@@ -22,10 +22,12 @@ import codecs
import glob
import json
import os
+import subprocess
from pathlib import Path
import soundfile
+from utils.utility import check_md5sum
from utils.utility import download
from utils.utility import unzip
@@ -35,12 +37,22 @@ DATA_HOME = os.path.expanduser('.')
BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/"
# dev data
-DEV_DATA_URL = BASE_URL + '/vox2_aac.zip'
-DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402"
+DEV_LIST = {
+ "vox2_dev_aac_partaa": "da070494c573e5c0564b1d11c3b20577",
+ "vox2_dev_aac_partab": "17fe6dab2b32b48abaf1676429cdd06f",
+ "vox2_dev_aac_partac": "1de58e086c5edf63625af1cb6d831528",
+ "vox2_dev_aac_partad": "5a043eb03e15c5a918ee6a52aad477f9",
+ "vox2_dev_aac_partae": "cea401b624983e2d0b2a87fb5d59aa60",
+ "vox2_dev_aac_partaf": "fc886d9ba90ab88e7880ee98effd6ae9",
+ "vox2_dev_aac_partag": "d160ecc3f6ee3eed54d55349531cb42e",
+ "vox2_dev_aac_partah": "6b84a81b9af72a9d9eecbb3b1f602e65",
+}
+
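+# each *_TARGET_DATA string is "<zip part glob> <packed zip name> <md5 of packed zip>";
+# download_dataset() concatenates the downloaded parts into the packed zip and verifies its md5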
+DEV_TARGET_DATA = "vox2_dev_aac_parta* vox2_dev_aac.zip bbc063c46078a602ca71605645c2a402"
# test data
-TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip'
-TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312"
+TEST_LIST = {"vox2_test_aac.zip": "0d2b3ea430a821c33263b5ea37ede312"}
+TEST_TARGET_DATA = "vox2_test_aac.zip vox2_test_aac.zip 0d2b3ea430a821c33263b5ea37ede312"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
@@ -68,6 +80,14 @@ args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
+ """Generate the voxceleb2 dataset manifest file.
+    We will create ${manifest_path_prefix}.vox2 as the final manifest file.
+ The dev and test wav info will be put in one manifest file.
+
+ Args:
+        data_dir (str): voxceleb2 wav directory, which includes the dev and test subsets
+ manifest_path_prefix (str): manifest file prefix
+ """
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_path = os.path.join(data_dir, "**", "*.wav")
@@ -119,7 +139,19 @@ def create_manifest(data_dir, manifest_path_prefix):
print(f"{total_sec / total_num} sec/utt", file=f)
-def download_dataset(url, md5sum, target_dir, dataset):
+def download_dataset(base_url, data_list, target_data, target_dir, dataset):
+ """Download the voxceleb2 zip package
+
+ Args:
+        base_url (str): base download url of the voxceleb2 dataset
+        data_list (dict): mapping from each zip part name to its md5 value
+        target_data (str): the zip part glob, the packed zip name and its md5 value
+        target_dir (str): directory to save the dataset
+ dataset (str): the dataset name, dev or test
+
+ Raises:
+        RuntimeError: if the md5 checksum of the packed zip file does not match
+ """
if not os.path.exists(target_dir):
os.makedirs(target_dir)
@@ -129,9 +161,34 @@ def download_dataset(url, md5sum, target_dir, dataset):
# but the test dataset will unzip to aac
# so, wo create the ${target_dir}/test and unzip the m4a to test dir
if not os.path.exists(os.path.join(target_dir, dataset)):
- filepath = download(url, md5sum, target_dir)
+ print(f"start to download the vox2 zip package to {target_dir}")
+ for zip_part in data_list.keys():
+ download_url = " --no-check-certificate " + base_url + "/" + zip_part
+ download(
+ url=download_url,
+ md5sum=data_list[zip_part],
+ target_dir=target_dir)
+
+        # pack all the parts into the target zip file
+ all_target_part, target_name, target_md5sum = target_data.split()
+ target_name = os.path.join(target_dir, target_name)
+ if not os.path.exists(target_name):
+ pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part,
+ target_name)
+ subprocess.call(pack_part_cmd, shell=True)
+
+ # check the target zip file md5sum
+ if not check_md5sum(target_name, target_md5sum):
+            raise RuntimeError("{} MD5 checksum failed".format(target_name))
+        else:
+            print("Checked the md5sum of {} successfully".format(target_name))
+
if dataset == "test":
- unzip(filepath, os.path.join(target_dir, "test"))
+            # we need to make the test directory first
+ unzip(target_name, os.path.join(target_dir, "test"))
+ else:
+            # unzip the dev zip package, which will create the dev directory
+ unzip(target_name, target_dir)
def main():
@@ -142,14 +199,16 @@ def main():
print("download: {}".format(args.download))
if args.download:
download_dataset(
- url=DEV_DATA_URL,
- md5sum=DEV_MD5SUM,
+ base_url=BASE_URL,
+ data_list=DEV_LIST,
+ target_data=DEV_TARGET_DATA,
target_dir=args.target_dir,
dataset="dev")
download_dataset(
- url=TEST_DATA_URL,
- md5sum=TEST_MD5SUM,
+ base_url=BASE_URL,
+ data_list=TEST_LIST,
+ target_data=TEST_TARGET_DATA,
target_dir=args.target_dir,
dataset="test")
diff --git a/demos/speaker_verification/README.md b/demos/speaker_verification/README.md
index 8739d402d..7d7180ae9 100644
--- a/demos/speaker_verification/README.md
+++ b/demos/speaker_verification/README.md
@@ -30,6 +30,11 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
paddlespeech vector --task spk --input vec.job
echo -e "demo2 85236145389.wav \n demo3 85236145389.wav" | paddlespeech vector --task spk
+
+ paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav"
+
+ echo -e "demo4 85236145389.wav 85236145389.wav \n demo5 85236145389.wav 123456789.wav" > vec.job
+ paddlespeech vector --task score --input vec.job
```
Usage:
@@ -38,6 +43,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
```
Arguments:
- `input`(required): Audio file to recognize.
+  - `task` (required): Task of the `vector` command, `spk` or `score`. Default: `spk`.
- `model`: Model type of vector task. Default: `ecapatdnn_voxceleb12`.
- `sample_rate`: Sample rate of the model. Default: `16000`.
- `config`: Config of vector task. Use pretrained model when it is None. Default: `None`.
@@ -47,45 +53,45 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
Output:
```bash
- demo [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
- -3.04878 1.611095 10.127234 -10.534177 -15.821609
- 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
- -11.343508 2.3385992 -8.719341 14.213509 15.404744
- -0.39327756 6.338786 2.688887 8.7104025 17.469526
- -8.77959 7.0576906 4.648855 -1.3089896 -23.294737
- 8.013747 13.891729 -9.926753 5.655307 -5.9422326
- -22.842539 0.6293588 -18.46266 -10.811862 9.8192625
- 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
- 1.7594414 -0.6485091 4.485623 2.0207152 7.264915
- -6.40137 23.63524 2.9711294 -22.708025 9.93719
- 20.354511 -10.324688 -0.700492 -8.783211 -5.27593
- 15.999649 3.3004563 12.747926 15.429879 4.7849145
- 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
- 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
- -9.224193 14.568347 -10.568833 4.982321 -4.342062
- 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
- -6.680575 0.4757669 -5.035051 -6.7964664 16.865469
- -11.54324 7.681869 0.44475392 9.708182 -8.932846
- 0.4123232 -4.361452 1.3948607 9.511665 0.11667654
- 2.9079323 6.049952 9.275183 -18.078873 6.2983274
- -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
- 4.010979 11.000591 -2.8873312 7.1352735 -16.79663
- 18.495346 -14.293832 7.89578 2.2714825 22.976387
- -4.875734 -3.0836344 -2.9999814 13.751918 6.448228
- -11.924197 2.171869 2.0423572 -6.173772 10.778437
- 25.77281 -4.9495463 14.57806 0.3044315 2.6132357
- -7.591999 -2.076944 9.025118 1.7834753 -3.1799617
- -4.9401326 23.465864 5.1685796 -9.018578 9.037825
- -4.4150195 6.859591 -12.274467 -0.88911164 5.186309
- -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
- -12.397416 -12.719869 -1.395601 2.1150916 5.7381287
- -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
- 8.731719 -20.778936 -11.495662 5.8033476 -4.752041
- 10.833007 -6.717991 4.504732 13.4244375 1.1306485
- 7.3435574 1.400918 14.704036 -9.501399 7.2315617
- -6.417456 1.3333273 11.872697 -0.30664724 8.8845
- 6.5569253 4.7948146 0.03662816 -8.704245 6.224871
- -3.2701402 -11.508579 ]
+ demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055
+ 1.756596 5.167894 10.80636 -3.8226728 -5.6141334
+ 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897
+ -9.723131 0.6619743 -6.976803 10.213478 7.494748
+ 2.9105635 3.8949256 3.7999806 7.1061673 16.905321
+ -7.1493764 8.733103 3.4230042 -4.831653 -11.403367
+ 11.232214 7.1274667 -4.2828417 2.452362 -5.130748
+ -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683
+ 0.7618269 1.1253023 -2.083836 4.725744 -8.782597
+ -3.539873 3.814236 5.1420674 2.162061 4.096431
+ -6.4162116 12.747448 1.9429878 -15.152943 6.417416
+ 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939
+ 11.567354 3.69788 11.258265 7.442363 9.183411
+ 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783
+ 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073
+ -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116
+ 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578
+ -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651
+ -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556
+ -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704
+ 3.272176 2.8382776 5.134597 -9.190781 -0.5657382
+ -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576
+ -0.31784213 9.493548 2.1144536 4.358092 -12.089823
+ 8.451689 -7.925461 4.6242585 4.4289427 18.692003
+ -2.6204622 -5.149185 -0.35821092 8.488551 4.981496
+ -9.32683 -2.2544234 6.6417594 1.2119585 10.977129
+ 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716
+ -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796
+ 0.66607 15.443222 4.740594 -3.4725387 11.592567
+ -2.054497 1.7361217 -8.265324 -9.30447 5.4068313
+ -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733
+ -8.649895 -9.998958 -2.564841 -0.53999114 2.601808
+ -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085
+ 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456
+ 7.3629923 0.4657332 3.132599 12.438889 -1.8337058
+ 4.532936 2.7264361 10.145339 -6.521951 2.897153
+ -3.3925855 5.079156 7.759716 4.677565 5.8457737
+ 2.402413 7.7071047 3.9711342 -6.390043 6.1268735
+ -3.7760346 -11.118123 ]
```
- Python API
@@ -97,56 +103,113 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
audio_emb = vector_executor(
model='ecapatdnn_voxceleb12',
sample_rate=16000,
- config=None,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./85236145389.wav',
- force_yes=False,
device=paddle.get_device())
print('Audio embedding Result: \n{}'.format(audio_emb))
+
+ test_emb = vector_executor(
+ model='ecapatdnn_voxceleb12',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./123456789.wav',
+ device=paddle.get_device())
+ print('Test embedding Result: \n{}'.format(test_emb))
+
+ # score range [0, 1]
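+  # the score measures the similarity between the enroll and test embeddings;
+  # a higher score means the two utterances are more likely from the same speaker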
+ score = vector_executor.get_embeddings_score(audio_emb, test_emb)
+  print(f"Embeddings Score: {score}")
```
- Output:
+ Output:
+
```bash
# Vector Result:
- [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
- -3.04878 1.611095 10.127234 -10.534177 -15.821609
- 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
- -11.343508 2.3385992 -8.719341 14.213509 15.404744
- -0.39327756 6.338786 2.688887 8.7104025 17.469526
- -8.77959 7.0576906 4.648855 -1.3089896 -23.294737
- 8.013747 13.891729 -9.926753 5.655307 -5.9422326
- -22.842539 0.6293588 -18.46266 -10.811862 9.8192625
- 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
- 1.7594414 -0.6485091 4.485623 2.0207152 7.264915
- -6.40137 23.63524 2.9711294 -22.708025 9.93719
- 20.354511 -10.324688 -0.700492 -8.783211 -5.27593
- 15.999649 3.3004563 12.747926 15.429879 4.7849145
- 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
- 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
- -9.224193 14.568347 -10.568833 4.982321 -4.342062
- 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
- -6.680575 0.4757669 -5.035051 -6.7964664 16.865469
- -11.54324 7.681869 0.44475392 9.708182 -8.932846
- 0.4123232 -4.361452 1.3948607 9.511665 0.11667654
- 2.9079323 6.049952 9.275183 -18.078873 6.2983274
- -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
- 4.010979 11.000591 -2.8873312 7.1352735 -16.79663
- 18.495346 -14.293832 7.89578 2.2714825 22.976387
- -4.875734 -3.0836344 -2.9999814 13.751918 6.448228
- -11.924197 2.171869 2.0423572 -6.173772 10.778437
- 25.77281 -4.9495463 14.57806 0.3044315 2.6132357
- -7.591999 -2.076944 9.025118 1.7834753 -3.1799617
- -4.9401326 23.465864 5.1685796 -9.018578 9.037825
- -4.4150195 6.859591 -12.274467 -0.88911164 5.186309
- -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
- -12.397416 -12.719869 -1.395601 2.1150916 5.7381287
- -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
- 8.731719 -20.778936 -11.495662 5.8033476 -4.752041
- 10.833007 -6.717991 4.504732 13.4244375 1.1306485
- 7.3435574 1.400918 14.704036 -9.501399 7.2315617
- -6.417456 1.3333273 11.872697 -0.30664724 8.8845
- 6.5569253 4.7948146 0.03662816 -8.704245 6.224871
- -3.2701402 -11.508579 ]
+ Audio embedding Result:
+ [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055
+ 1.756596 5.167894 10.80636 -3.8226728 -5.6141334
+ 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897
+ -9.723131 0.6619743 -6.976803 10.213478 7.494748
+ 2.9105635 3.8949256 3.7999806 7.1061673 16.905321
+ -7.1493764 8.733103 3.4230042 -4.831653 -11.403367
+ 11.232214 7.1274667 -4.2828417 2.452362 -5.130748
+ -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683
+ 0.7618269 1.1253023 -2.083836 4.725744 -8.782597
+ -3.539873 3.814236 5.1420674 2.162061 4.096431
+ -6.4162116 12.747448 1.9429878 -15.152943 6.417416
+ 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939
+ 11.567354 3.69788 11.258265 7.442363 9.183411
+ 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783
+ 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073
+ -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116
+ 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578
+ -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651
+ -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556
+ -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704
+ 3.272176 2.8382776 5.134597 -9.190781 -0.5657382
+ -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576
+ -0.31784213 9.493548 2.1144536 4.358092 -12.089823
+ 8.451689 -7.925461 4.6242585 4.4289427 18.692003
+ -2.6204622 -5.149185 -0.35821092 8.488551 4.981496
+ -9.32683 -2.2544234 6.6417594 1.2119585 10.977129
+ 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716
+ -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796
+ 0.66607 15.443222 4.740594 -3.4725387 11.592567
+ -2.054497 1.7361217 -8.265324 -9.30447 5.4068313
+ -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733
+ -8.649895 -9.998958 -2.564841 -0.53999114 2.601808
+ -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085
+ 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456
+ 7.3629923 0.4657332 3.132599 12.438889 -1.8337058
+ 4.532936 2.7264361 10.145339 -6.521951 2.897153
+ -3.3925855 5.079156 7.759716 4.677565 5.8457737
+ 2.402413 7.7071047 3.9711342 -6.390043 6.1268735
+ -3.7760346 -11.118123 ]
+ # get the test embedding
+ Test embedding Result:
+ [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125
+ 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845
+ -8.04332 4.344489 2.3200977 -14.306299 5.184692
+ -11.55602 -3.8497238 0.6444722 1.2833948 2.6766639
+ 0.5878921 0.7946299 1.7207596 2.5791872 14.998469
+ -1.3385371 15.031221 -0.8006958 1.99287 -9.52007
+ 2.435466 4.003221 -4.33817 -4.898601 -5.304714
+ -18.033886 10.790787 -12.784645 -5.641755 2.9761686
+ -10.566622 1.4839455 6.152458 -5.7195854 2.8603241
+ 6.112133 8.489869 5.5958056 1.2836679 -1.2293907
+ 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906
+ 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029
+ 10.521315 0.9604709 -3.3257897 7.144871 -13.592733
+ -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123
+ 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357
+ -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422
+ 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846
+ 1.3420814 4.240421 -2.772944 -2.8451524 16.311104
+ 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239
+ -1.5708797 1.568961 1.1413603 3.5032008 -0.45251232
+ -6.786333 16.89443 5.3366146 -8.789056 0.6355629
+ 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468
+ -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873
+ 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476
+ 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166
+ -5.249467 -2.2671914 7.2658715 -13.298164 4.821147
+ -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838
+ -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094
+ -3.4903061 2.2408795 5.5010734 -3.970756 11.99696
+ -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706
+ 11.635366 0.72157705 -9.245714 -3.91465 -4.449838
+ -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864
+ 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571
+ 13.270688 3.448231 -7.0659585 4.5886116 -4.466099
+ -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137
+ -1.9962958 2.7119343 19.391657 0.01961198 14.607133
+ -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604
+ 12.0612335 5.9285784 3.3715196 1.492534 10.723728
+ -0.95514804 -12.085431 ]
+ # get the score between enroll and test
+  Embeddings Score: 0.4292638301849365
```
### 4.Pretrained Models
diff --git a/demos/speaker_verification/README_cn.md b/demos/speaker_verification/README_cn.md
index fe8949b3c..db382f298 100644
--- a/demos/speaker_verification/README_cn.md
+++ b/demos/speaker_verification/README_cn.md
@@ -29,6 +29,11 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
paddlespeech vector --task spk --input vec.job
echo -e "demo2 85236145389.wav \n demo3 85236145389.wav" | paddlespeech vector --task spk
+
+ paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav"
+
+ echo -e "demo4 85236145389.wav 85236145389.wav \n demo5 85236145389.wav 123456789.wav" > vec.job
+ paddlespeech vector --task score --input vec.job
```
使用方法:
@@ -37,6 +42,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
```
参数:
- `input`(必须输入):用于识别的音频文件。
+ - `task` (必须输入): 用于指定 `vector` 处理的具体任务,默认是 `spk`。
- `model`:声纹任务的模型,默认值:`ecapatdnn_voxceleb12`。
- `sample_rate`:音频采样率,默认值:`16000`。
- `config`:声纹任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
@@ -45,45 +51,45 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
输出:
```bash
- demo [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
- -3.04878 1.611095 10.127234 -10.534177 -15.821609
- 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
- -11.343508 2.3385992 -8.719341 14.213509 15.404744
- -0.39327756 6.338786 2.688887 8.7104025 17.469526
- -8.77959 7.0576906 4.648855 -1.3089896 -23.294737
- 8.013747 13.891729 -9.926753 5.655307 -5.9422326
- -22.842539 0.6293588 -18.46266 -10.811862 9.8192625
- 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
- 1.7594414 -0.6485091 4.485623 2.0207152 7.264915
- -6.40137 23.63524 2.9711294 -22.708025 9.93719
- 20.354511 -10.324688 -0.700492 -8.783211 -5.27593
- 15.999649 3.3004563 12.747926 15.429879 4.7849145
- 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
- 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
- -9.224193 14.568347 -10.568833 4.982321 -4.342062
- 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
- -6.680575 0.4757669 -5.035051 -6.7964664 16.865469
- -11.54324 7.681869 0.44475392 9.708182 -8.932846
- 0.4123232 -4.361452 1.3948607 9.511665 0.11667654
- 2.9079323 6.049952 9.275183 -18.078873 6.2983274
- -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
- 4.010979 11.000591 -2.8873312 7.1352735 -16.79663
- 18.495346 -14.293832 7.89578 2.2714825 22.976387
- -4.875734 -3.0836344 -2.9999814 13.751918 6.448228
- -11.924197 2.171869 2.0423572 -6.173772 10.778437
- 25.77281 -4.9495463 14.57806 0.3044315 2.6132357
- -7.591999 -2.076944 9.025118 1.7834753 -3.1799617
- -4.9401326 23.465864 5.1685796 -9.018578 9.037825
- -4.4150195 6.859591 -12.274467 -0.88911164 5.186309
- -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
- -12.397416 -12.719869 -1.395601 2.1150916 5.7381287
- -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
- 8.731719 -20.778936 -11.495662 5.8033476 -4.752041
- 10.833007 -6.717991 4.504732 13.4244375 1.1306485
- 7.3435574 1.400918 14.704036 -9.501399 7.2315617
- -6.417456 1.3333273 11.872697 -0.30664724 8.8845
- 6.5569253 4.7948146 0.03662816 -8.704245 6.224871
- -3.2701402 -11.508579 ]
+ demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055
+ 1.756596 5.167894 10.80636 -3.8226728 -5.6141334
+ 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897
+ -9.723131 0.6619743 -6.976803 10.213478 7.494748
+ 2.9105635 3.8949256 3.7999806 7.1061673 16.905321
+ -7.1493764 8.733103 3.4230042 -4.831653 -11.403367
+ 11.232214 7.1274667 -4.2828417 2.452362 -5.130748
+ -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683
+ 0.7618269 1.1253023 -2.083836 4.725744 -8.782597
+ -3.539873 3.814236 5.1420674 2.162061 4.096431
+ -6.4162116 12.747448 1.9429878 -15.152943 6.417416
+ 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939
+ 11.567354 3.69788 11.258265 7.442363 9.183411
+ 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783
+ 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073
+ -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116
+ 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578
+ -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651
+ -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556
+ -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704
+ 3.272176 2.8382776 5.134597 -9.190781 -0.5657382
+ -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576
+ -0.31784213 9.493548 2.1144536 4.358092 -12.089823
+ 8.451689 -7.925461 4.6242585 4.4289427 18.692003
+ -2.6204622 -5.149185 -0.35821092 8.488551 4.981496
+ -9.32683 -2.2544234 6.6417594 1.2119585 10.977129
+ 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716
+ -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796
+ 0.66607 15.443222 4.740594 -3.4725387 11.592567
+ -2.054497 1.7361217 -8.265324 -9.30447 5.4068313
+ -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733
+ -8.649895 -9.998958 -2.564841 -0.53999114 2.601808
+ -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085
+ 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456
+ 7.3629923 0.4657332 3.132599 12.438889 -1.8337058
+ 4.532936 2.7264361 10.145339 -6.521951 2.897153
+ -3.3925855 5.079156 7.759716 4.677565 5.8457737
+ 2.402413 7.7071047 3.9711342 -6.390043 6.1268735
+ -3.7760346 -11.118123 ]
```
- Python API
@@ -98,53 +104,109 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./85236145389.wav',
- force_yes=False,
device=paddle.get_device())
print('Audio embedding Result: \n{}'.format(audio_emb))
+
+ test_emb = vector_executor(
+ model='ecapatdnn_voxceleb12',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./123456789.wav',
+ device=paddle.get_device())
+ print('Test embedding Result: \n{}'.format(test_emb))
+
+ # score range [0, 1]
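+  # the score measures the similarity between the enroll and test embeddings;
+  # a higher score means the two utterances are more likely from the same speaker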
+ score = vector_executor.get_embeddings_score(audio_emb, test_emb)
+  print(f"Embeddings Score: {score}")
```
输出:
```bash
# Vector Result:
- [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
- -3.04878 1.611095 10.127234 -10.534177 -15.821609
- 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
- -11.343508 2.3385992 -8.719341 14.213509 15.404744
- -0.39327756 6.338786 2.688887 8.7104025 17.469526
- -8.77959 7.0576906 4.648855 -1.3089896 -23.294737
- 8.013747 13.891729 -9.926753 5.655307 -5.9422326
- -22.842539 0.6293588 -18.46266 -10.811862 9.8192625
- 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
- 1.7594414 -0.6485091 4.485623 2.0207152 7.264915
- -6.40137 23.63524 2.9711294 -22.708025 9.93719
- 20.354511 -10.324688 -0.700492 -8.783211 -5.27593
- 15.999649 3.3004563 12.747926 15.429879 4.7849145
- 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
- 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
- -9.224193 14.568347 -10.568833 4.982321 -4.342062
- 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
- -6.680575 0.4757669 -5.035051 -6.7964664 16.865469
- -11.54324 7.681869 0.44475392 9.708182 -8.932846
- 0.4123232 -4.361452 1.3948607 9.511665 0.11667654
- 2.9079323 6.049952 9.275183 -18.078873 6.2983274
- -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
- 4.010979 11.000591 -2.8873312 7.1352735 -16.79663
- 18.495346 -14.293832 7.89578 2.2714825 22.976387
- -4.875734 -3.0836344 -2.9999814 13.751918 6.448228
- -11.924197 2.171869 2.0423572 -6.173772 10.778437
- 25.77281 -4.9495463 14.57806 0.3044315 2.6132357
- -7.591999 -2.076944 9.025118 1.7834753 -3.1799617
- -4.9401326 23.465864 5.1685796 -9.018578 9.037825
- -4.4150195 6.859591 -12.274467 -0.88911164 5.186309
- -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
- -12.397416 -12.719869 -1.395601 2.1150916 5.7381287
- -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
- 8.731719 -20.778936 -11.495662 5.8033476 -4.752041
- 10.833007 -6.717991 4.504732 13.4244375 1.1306485
- 7.3435574 1.400918 14.704036 -9.501399 7.2315617
- -6.417456 1.3333273 11.872697 -0.30664724 8.8845
- 6.5569253 4.7948146 0.03662816 -8.704245 6.224871
- -3.2701402 -11.508579 ]
+ Audio embedding Result:
+ [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055
+ 1.756596 5.167894 10.80636 -3.8226728 -5.6141334
+ 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897
+ -9.723131 0.6619743 -6.976803 10.213478 7.494748
+ 2.9105635 3.8949256 3.7999806 7.1061673 16.905321
+ -7.1493764 8.733103 3.4230042 -4.831653 -11.403367
+ 11.232214 7.1274667 -4.2828417 2.452362 -5.130748
+ -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683
+ 0.7618269 1.1253023 -2.083836 4.725744 -8.782597
+ -3.539873 3.814236 5.1420674 2.162061 4.096431
+ -6.4162116 12.747448 1.9429878 -15.152943 6.417416
+ 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939
+ 11.567354 3.69788 11.258265 7.442363 9.183411
+ 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783
+ 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073
+ -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116
+ 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578
+ -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651
+ -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556
+ -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704
+ 3.272176 2.8382776 5.134597 -9.190781 -0.5657382
+ -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576
+ -0.31784213 9.493548 2.1144536 4.358092 -12.089823
+ 8.451689 -7.925461 4.6242585 4.4289427 18.692003
+ -2.6204622 -5.149185 -0.35821092 8.488551 4.981496
+ -9.32683 -2.2544234 6.6417594 1.2119585 10.977129
+ 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716
+ -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796
+ 0.66607 15.443222 4.740594 -3.4725387 11.592567
+ -2.054497 1.7361217 -8.265324 -9.30447 5.4068313
+ -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733
+ -8.649895 -9.998958 -2.564841 -0.53999114 2.601808
+ -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085
+ 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456
+ 7.3629923 0.4657332 3.132599 12.438889 -1.8337058
+ 4.532936 2.7264361 10.145339 -6.521951 2.897153
+ -3.3925855 5.079156 7.759716 4.677565 5.8457737
+ 2.402413 7.7071047 3.9711342 -6.390043 6.1268735
+ -3.7760346 -11.118123 ]
+ # get the test embedding
+ Test embedding Result:
+ [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125
+ 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845
+ -8.04332 4.344489 2.3200977 -14.306299 5.184692
+ -11.55602 -3.8497238 0.6444722 1.2833948 2.6766639
+ 0.5878921 0.7946299 1.7207596 2.5791872 14.998469
+ -1.3385371 15.031221 -0.8006958 1.99287 -9.52007
+ 2.435466 4.003221 -4.33817 -4.898601 -5.304714
+ -18.033886 10.790787 -12.784645 -5.641755 2.9761686
+ -10.566622 1.4839455 6.152458 -5.7195854 2.8603241
+ 6.112133 8.489869 5.5958056 1.2836679 -1.2293907
+ 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906
+ 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029
+ 10.521315 0.9604709 -3.3257897 7.144871 -13.592733
+ -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123
+ 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357
+ -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422
+ 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846
+ 1.3420814 4.240421 -2.772944 -2.8451524 16.311104
+ 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239
+ -1.5708797 1.568961 1.1413603 3.5032008 -0.45251232
+ -6.786333 16.89443 5.3366146 -8.789056 0.6355629
+ 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468
+ -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873
+ 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476
+ 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166
+ -5.249467 -2.2671914 7.2658715 -13.298164 4.821147
+ -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838
+ -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094
+ -3.4903061 2.2408795 5.5010734 -3.970756 11.99696
+ -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706
+ 11.635366 0.72157705 -9.245714 -3.91465 -4.449838
+ -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864
+ 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571
+ 13.270688 3.448231 -7.0659585 4.5886116 -4.466099
+ -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137
+ -1.9962958 2.7119343 19.391657 0.01961198 14.607133
+ -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604
+ 12.0612335 5.9285784 3.3715196 1.492534 10.723728
+ -0.95514804 -12.085431 ]
+ # get the score between enroll and test
+  Embeddings Score: 0.4292638301849365
```
### 4.预训练模型
diff --git a/demos/speaker_verification/run.sh b/demos/speaker_verification/run.sh
index 856886d33..6140f7f38 100644
--- a/demos/speaker_verification/run.sh
+++ b/demos/speaker_verification/run.sh
@@ -1,6 +1,9 @@
#!/bin/bash
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
+wget -c https://paddlespeech.bj.bcebos.com/vector/audio/123456789.wav
-# asr
-paddlespeech vector --task spk --input ./85236145389.wav
\ No newline at end of file
+# vector
+paddlespeech vector --task spk --input ./85236145389.wav
+
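+# compute the similarity score between the embeddings of the two utterances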
+paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav"
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 9a423e03e..1cbe39895 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -6,7 +6,7 @@
### Speech Recognition Model
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
-[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
+[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.078 |-| 151 h | [Ds2 Online Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0565 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0483 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1)
@@ -37,8 +37,8 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (stati
Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)|||
Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB|
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
-SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip)|12MB|
-FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip)|157MB|
+SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB|
+FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB|
FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)|||
FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)|||
FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts3)|[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)|||
@@ -80,7 +80,7 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https
Model Type | Dataset| Example Link | Pretrained Models | Static Models
:-------------:| :------------:| :-----: | :-----: | :-----:
-PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz) | -
+ECAPA-TDNN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz) | -
## Punctuation Restoration Models
Model Type | Dataset| Example Link | Pretrained Models
diff --git a/examples/aishell/asr0/README.md b/examples/aishell/asr0/README.md
index bb45d8df0..4459b1382 100644
--- a/examples/aishell/asr0/README.md
+++ b/examples/aishell/asr0/README.md
@@ -151,21 +151,14 @@ avg.sh best exp/deepspeech2/checkpoints 1
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/deepspeech2.yaml exp/deepspeech2/checkpoints/avg_1
```
## Pretrained Model
-You can get the pretrained transformer or conformer using the scripts below:
-```bash
-Deepspeech2 offline:
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz
-
-Deepspeech2 online:
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz
+You can get the pretrained models from [released_model.md](../../../docs/source/released_model.md).
-```
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz
-tar xzvf ds2.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz
+tar xzvf asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz
source path.sh
# If you have process the data and get the manifest file, you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
@@ -173,12 +166,7 @@ bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/deepspeech2.yaml exp/deepspeech2/checkpoints/avg_1
```
-The performance of the released models are shown below:
-
-| Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech |
-| :----------------------------: | :-------------: | :---------: | -----: | :------------------------------------------------- | :---- | :--- | :-------------- |
-| Ds2 Online Aishell ASR0 Model | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 | - | 151 h |
-| Ds2 Offline Aishell ASR0 Model | Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers | 0.064 | - | 151 h |
+The performance of the released models is shown in [RESULTS.md](./RESULTS.md).
## Stage 4: Static graph model Export
This stage is to transform dygraph to static graph.
```bash
@@ -214,8 +202,8 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
```
you can train the model by yourself, or you can download the pretrained model by the script below:
```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz
-tar xzvf ds2.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz
+tar xzvf asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz
```
You can download the audio demo:
```bash
diff --git a/examples/aishell/asr0/RESULTS.md b/examples/aishell/asr0/RESULTS.md
index 5841a8522..8af3d66d1 100644
--- a/examples/aishell/asr0/RESULTS.md
+++ b/examples/aishell/asr0/RESULTS.md
@@ -4,15 +4,16 @@
| Model | Number of Params | Release | Config | Test set | Valid Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 45.18M | 2.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.994938373565674 | 0.080 |
+| DeepSpeech2 | 45.18M | r0.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.708217620849609 | 0.078 |
+| DeepSpeech2 | 45.18M | v2.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.994938373565674 | 0.080 |
## Deepspeech2 Non-Streaming
| Model | Number of Params | Release | Config | Test set | Valid Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 |
-| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
-| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
-| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
+| DeepSpeech2 | 58.4M | v2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 |
+| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
+| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
+| DeepSpeech2 | 58.4M | v2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| --- | --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |
+| DeepSpeech2 | 58.4M | v1.8.5 | - | test | - | 0.080447 |
diff --git a/examples/aishell/asr1/README.md b/examples/aishell/asr1/README.md
index 5277a31eb..25b28ede8 100644
--- a/examples/aishell/asr1/README.md
+++ b/examples/aishell/asr1/README.md
@@ -143,25 +143,14 @@ avg.sh best exp/conformer/checkpoints 20
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20
```
## Pretrained Model
-You can get the pretrained transformer or conformer using the scripts below:
+You can get the pretrained transformer or conformer models from [released_model.md](../../../docs/source/released_model.md).
-```bash
-# Conformer:
-wget https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz
-
-# Chunk Conformer:
-wget https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz
-
-# Transformer:
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz
-
-```
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz
-tar xzvf transformer.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz
+tar xzvf asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz
source path.sh
# If you have process the data and get the manifest file, you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
@@ -206,7 +195,7 @@ In some situations, you want to use the trained model to do the inference for th
```
you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz
-tar xzvf transformer.model.tar.gz
+tar xzvf asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz
```
You can download the audio demo:
diff --git a/examples/aishell3/vc0/README.md b/examples/aishell3/vc0/README.md
index 664ec1ac3..925663ab1 100644
--- a/examples/aishell3/vc0/README.md
+++ b/examples/aishell3/vc0/README.md
@@ -118,7 +118,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_outpu
```
## Pretrained Model
-[tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip)
+- [tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/mse_loss | eval/bce_loss| eval/attn_loss
diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md
index 04b83a5ff..8ab0f9c8c 100644
--- a/examples/aishell3/vc1/README.md
+++ b/examples/aishell3/vc1/README.md
@@ -119,7 +119,7 @@ ref_audio
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir}
```
## Pretrained Model
-[fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip)
+- [fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
diff --git a/examples/aishell3/voc1/README.md b/examples/aishell3/voc1/README.md
index dad464092..eb30e7c40 100644
--- a/examples/aishell3/voc1/README.md
+++ b/examples/aishell3/voc1/README.md
@@ -137,7 +137,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-Pretrained models can be downloaded here [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip).
+Pretrained models can be downloaded here:
+- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss:| eval/spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md
index ebe2530be..c957c4a3a 100644
--- a/examples/aishell3/voc5/README.md
+++ b/examples/aishell3/voc5/README.md
@@ -136,7 +136,8 @@ optional arguments:
4. `--output-dir` is the directory to save the synthesized audio files.
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip).
+The pretrained model can be downloaded here:
+- [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml
new file mode 100755
index 000000000..319e44976
--- /dev/null
+++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml
@@ -0,0 +1,62 @@
+###########################################################
+# AMI DATA PREPARE SETTING #
+###########################################################
+split_type: 'full_corpus_asr'
+skip_TNO: True
+# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
+mic_type: 'Mix-Headset'
+vad_type: 'oracle'
+max_subseg_dur: 3.0
+overlap: 1.5
+# Some more exp folders (for cleaner structure).
+embedding_dir: emb #!ref /emb
+meta_data_dir: metadata #!ref /metadata
+ref_rttm_dir: ref_rttms #!ref /ref_rttms
+sys_rttm_dir: sys_rttms #!ref /sys_rttms
+der_dir: DER #!ref /DER
+
+
+###########################################################
+# FEATURE EXTRACTION SETTING #
+###########################################################
+# currently, we only support fbank
+sr: 16000 # sample rate
+n_mels: 80
+window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
+hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
+#left_frames: 0
+#right_frames: 0
+#deltas: False
+
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
+# if we want use another model, please choose another configuration yaml file
+seed: 1234
+emb_dim: 192
+batch_size: 16
+model:
+ input_size: 80
+ channels: [1024, 1024, 1024, 1024, 3072]
+ kernel_sizes: [5, 3, 3, 3, 1]
+ dilations: [1, 2, 3, 4, 1]
+ attention_channels: 128
+ lin_neurons: 192
+# Will automatically download ECAPA-TDNN model (best).
+
+###########################################################
+# SPECTRAL CLUSTERING SETTING #
+###########################################################
+backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
+affinity: 'cos' # options: cos, nn
+max_num_spkrs: 10
+oracle_n_spkrs: True
+
+
+###########################################################
+# DER EVALUATION SETTING #
+###########################################################
+ignore_overlap: True
+forgiveness_collar: 0.25
diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py
new file mode 100644
index 000000000..dc824d7ca
--- /dev/null
+++ b/examples/ami/sd0/local/compute_embdding.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+import pickle
+import sys
+
+import numpy as np
+import paddle
+from paddle.io import BatchSampler
+from paddle.io import DataLoader
+from tqdm.contrib import tqdm
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.cluster.diarization import EmbeddingMeta
+from paddlespeech.vector.io.batch import batch_feature_normalize
+from paddlespeech.vector.io.dataset_from_json import JSONDataset
+from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.seeding import seed_everything
+
+# Logger setup
+logger = Log(__name__).getlog()
+
+
+def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
+ """Prepares metadata for a given recording ID.
+
+ Arguments
+ ---------
+ full_meta_data : json
+ Full meta (json) containing all the recordings
+ rec_id : str
+ The recording ID for which meta (json) has to be prepared
+ out_meta_file : str
+ Path of the output meta (json) file.
+ """
+
+ subset = {}
+ for key in full_meta_data:
+ k = str(key)
+ if k.startswith(rec_id):
+ subset[key] = full_meta_data[key]
+
+ with open(out_meta_file, mode="w") as json_f:
+ json.dump(subset, json_f, indent=2)
+
+
+def create_dataloader(json_file, batch_size):
+ """Creates the datasets and their data processing pipelines.
+ This is used for multi-mic processing.
+ """
+
+ # create datasets
+ dataset = JSONDataset(
+ json_file=json_file,
+ feat_type='melspectrogram',
+ n_mels=config.n_mels,
+ window_size=config.window_size,
+ hop_length=config.hop_size)
+
+ # create dataloader
+ batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
+ dataloader = DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda x: batch_feature_normalize(
+ x, mean_norm=True, std_norm=False),
+ return_list=True)
+
+ return dataloader
+
+
+def main(args, config):
+ # set the training device, cpu or gpu
+ paddle.set_device(args.device)
+ # set the random seed
+ seed_everything(config.seed)
+
+ # stage1: build the dnn backbone model network
+ ecapa_tdnn = EcapaTdnn(**config.model)
+
+ # stage2: build the speaker verification eval instance with backbone model
+ model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)
+
+ # stage3: load the pre-trained model
+    # we load the last checkpoint, which is determined by the epoch and save_interval
+ args.load_checkpoint = os.path.abspath(
+ os.path.expanduser(args.load_checkpoint))
+
+ # load model checkpoint to sid model
+ state_dict = paddle.load(
+ os.path.join(args.load_checkpoint, 'model.pdparams'))
+ model.set_state_dict(state_dict)
+ logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
+
+ # set the model to eval mode
+ model.eval()
+
+ # load meta data
+ meta_file = os.path.join(
+ args.data_dir,
+ config.meta_data_dir,
+ "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
+ with open(meta_file, "r") as f:
+ full_meta = json.load(f)
+
+ # get all the recording IDs in this dataset.
+ all_keys = full_meta.keys()
+ A = [word.rstrip().split("_")[0] for word in all_keys]
+ all_rec_ids = list(set(A[1:]))
+ all_rec_ids.sort()
+ split = "AMI_" + args.dataset
+ i = 1
+
+ msg = "Extra embdding for " + args.dataset + " set"
+ logger.info(msg)
+
+ if len(all_rec_ids) <= 0:
+ msg = "No recording IDs found! Please check if meta_data json file is properly generated."
+ logger.error(msg)
+ sys.exit()
+
+    # extract embeddings for the different recordings in the dataset.
+ for rec_id in tqdm(all_rec_ids):
+ # This tag will be displayed in the log.
+ tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
+ str(len(all_rec_ids)) + "]")
+ i = i + 1
+
+ # log message.
+ msg = "Embdding %s : %s " % (tag, rec_id)
+ logger.debug(msg)
+
+ # embedding directory.
+ if not os.path.exists(
+ os.path.join(args.data_dir, config.embedding_dir, split)):
+ os.makedirs(
+ os.path.join(args.data_dir, config.embedding_dir, split))
+
+ # file to store embeddings.
+ emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
+ diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
+ split, emb_file_name)
+
+        # prepare a metadata (json) for one recording. This is basically a subset of full_meta.
+        # let's keep this meta-info in the embedding directory itself.
+ json_file_name = rec_id + "." + config.mic_type + ".json"
+ meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
+ split, json_file_name)
+
+ # write subset (meta for one recording) json metadata.
+ prepare_subset_json(full_meta, rec_id, meta_per_rec_file)
+
+ # prepare data loader.
+ diary_set_loader = create_dataloader(meta_per_rec_file,
+ config.batch_size)
+
+ # extract embeddings (skip if already done).
+ if not os.path.isfile(diary_stat_emb_file):
+ logger.debug("Extracting deep embeddings")
+ embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
+ segset = []
+
+ for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
+                # extract the audio embedding
+ ids, feats, lengths = batch['ids'], batch['feats'], batch[
+ 'lengths']
+ seg = [x for x in ids]
+ segset = segset + seg
+ emb = model.backbone(feats, lengths).squeeze(
+ -1).numpy() # (N, emb_size, 1) -> (N, emb_size)
+ embeddings = np.concatenate((embeddings, emb), axis=0)
+
+ segset = np.array(segset, dtype="|O")
+ stat_obj = EmbeddingMeta(
+ segset=segset,
+ stats=embeddings, )
+ logger.debug("Saving Embeddings...")
+ with open(diary_stat_emb_file, "wb") as output:
+ pickle.dump(stat_obj, output)
+
+ else:
+ logger.debug("Skipping embedding extraction (as already present).")
+
+
+# Begin experiment!
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(__doc__)
+ parser.add_argument(
+ '--device',
+ default="gpu",
+ help="Select which device to perform diarization, defaults to gpu.")
+ parser.add_argument(
+ "--config", default=None, type=str, help="configuration file")
+ parser.add_argument(
+ "--data-dir",
+ default="../save/",
+ type=str,
+ help="processsed data directory")
+ parser.add_argument(
+ "--dataset",
+ choices=['dev', 'eval'],
+ default="dev",
+ type=str,
+ help="Select which dataset to extra embdding, defaults to dev")
+ parser.add_argument(
+ "--load-checkpoint",
+ type=str,
+ default='',
+ help="Directory to load model checkpoint to compute embeddings.")
+ args = parser.parse_args()
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+
+ config.freeze()
+
+ main(args, config)
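The core pattern in `compute_embdding.py` is to accumulate per-segment embeddings into a single array, pair them with their segment IDs, and pickle the result. The stand-alone sketch below illustrates that pattern without Paddle: `fake_embedding` replaces the real `model.backbone(feats, lengths)` call and `SimpleEmbeddingMeta` stands in for `EmbeddingMeta`; both names are hypothetical.

```python
# Stand-alone sketch of the embedding accumulation + pickling pattern (stand-ins noted above).
import pickle
from dataclasses import dataclass

import numpy as np

@dataclass
class SimpleEmbeddingMeta:          # stand-in for paddlespeech's EmbeddingMeta
    segset: np.ndarray              # segment ids, dtype object
    stats: np.ndarray               # (num_segments, emb_dim) embeddings

emb_dim = 192
embeddings = np.empty(shape=[0, emb_dim], dtype=np.float64)
segset = []

def fake_embedding(seg_id: str) -> np.ndarray:
    # placeholder for model.backbone(feats, lengths).squeeze(-1).numpy()
    rng = np.random.default_rng(abs(hash(seg_id)) % (2**32))
    return rng.standard_normal((1, emb_dim))

for seg_id in ["rec1_0001", "rec1_0002", "rec1_0003"]:
    segset.append(seg_id)
    embeddings = np.concatenate((embeddings, fake_embedding(seg_id)), axis=0)

stat_obj = SimpleEmbeddingMeta(segset=np.array(segset, dtype="|O"), stats=embeddings)
with open("rec1.emb_stat.pkl", "wb") as f:
    pickle.dump(stat_obj, f)
print(stat_obj.stats.shape)         # (3, 192)
```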
diff --git a/examples/ami/sd0/local/data.sh b/examples/ami/sd0/local/data.sh
deleted file mode 100755
index 478ec432d..000000000
--- a/examples/ami/sd0/local/data.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-stage=1
-
-TARGET_DIR=${MAIN_ROOT}/dataset/ami
-data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
-manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
-
-save_folder=${MAIN_ROOT}/examples/ami/sd0/data
-ref_rttm_dir=${save_folder}/ref_rttms
-meta_data_dir=${save_folder}/metadata
-
-set=L
-
-. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
-set -u
-set -o pipefail
-
-mkdir -p ${save_folder}
-
-if [ ${stage} -le 0 ]; then
- # Download AMI corpus, You need around 10GB of free space to get whole data
- # The signals are too large to package in this way,
- # so you need to use the chooser to indicate which ones you wish to download
- echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
- echo "Annotations: AMI manual annotations v1.6.2 "
- echo "Signals: "
- echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
- echo "2) Select media streams: Just select Headset mix"
- exit 0;
-fi
-
-if [ ${stage} -le 1 ]; then
- echo "AMI Data preparation"
-
- python local/ami_prepare.py --data_folder ${data_folder} \
- --manual_annot_folder ${manual_annot_folder} \
- --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
- --meta_data_dir ${meta_data_dir}
-
- if [ $? -ne 0 ]; then
- echo "Prepare AMI failed. Please check log message."
- exit 1
- fi
-
-fi
-
-echo "AMI data preparation done."
-exit 0
diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py
new file mode 100755
index 000000000..298228376
--- /dev/null
+++ b/examples/ami/sd0/local/experiment.py
@@ -0,0 +1,428 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+import pickle
+import shutil
+import sys
+
+import numpy as np
+from tqdm.contrib import tqdm
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.cluster import diarization as diar
+from utils.DER import DER
+
+# Logger setup
+logger = Log(__name__).getlog()
+
+
+def diarize_dataset(
+ full_meta,
+ split_type,
+ n_lambdas,
+ pval,
+ save_dir,
+ config,
+ n_neighbors=10, ):
+ """This function diarizes all the recordings in a given dataset. It performs
+ computation of embedding and clusters them using spectral clustering (or other backends).
+ The output speaker boundary file is stored in the RTTM format.
+ """
+
+ # prepare `spkr_info` only once when Oracle num of speakers is selected.
+ # spkr_info is essential to obtain number of speakers from groundtruth.
+ if config.oracle_n_spkrs is True:
+ full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+ "fullref_ami_" + split_type + ".rttm")
+ rttm = diar.read_rttm(full_ref_rttm_file)
+
+ spkr_info = list( # noqa F841
+ filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+ # get all the recording IDs in this dataset.
+ all_keys = full_meta.keys()
+ A = [word.rstrip().split("_")[0] for word in all_keys]
+ all_rec_ids = list(set(A[1:]))
+ all_rec_ids.sort()
+ split = "AMI_" + split_type
+ i = 1
+
+ # adding tag for directory path.
+ type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+ tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)
+
+ # make out rttm dir
+ out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
+ split, tag)
+ if not os.path.exists(out_rttm_dir):
+ os.makedirs(out_rttm_dir)
+
+ # diarizing different recordings in a dataset.
+ for rec_id in tqdm(all_rec_ids):
+ # this tag will be displayed in the log.
+ tag = ("[" + str(split_type) + ": " + str(i) + "/" +
+ str(len(all_rec_ids)) + "]")
+ i = i + 1
+
+ # log message.
+ msg = "Diarizing %s : %s " % (tag, rec_id)
+ logger.debug(msg)
+
+ # load embeddings.
+ emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
+ diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
+ split, emb_file_name)
+ if not os.path.isfile(diary_stat_emb_file):
+ msg = "Embdding file %s not found! Please check if embdding file is properly generated." % (
+ diary_stat_emb_file)
+ logger.error(msg)
+ sys.exit()
+ with open(diary_stat_emb_file, "rb") as in_file:
+ diary_obj = pickle.load(in_file)
+
+ out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"
+
+ # processing starts from here.
+ if config.oracle_n_spkrs is True:
+ # oracle num of speakers.
+ num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
+ else:
+ if config.affinity == "nn":
+                # num of speakers tuned on dev set (only for nn affinity).
+ num_spkrs = n_lambdas
+ else:
+                # num of speakers will be estimated using the max eigen gap for cos-based affinity,
+                # so None is used here and handled later on.
+ num_spkrs = None
+
+ if config.backend == "kmeans":
+ diar.do_kmeans_clustering(
+ diary_obj,
+ out_rttm_file,
+ rec_id,
+ num_spkrs,
+ pval, )
+
+ if config.backend == "SC":
+ # go for Spectral Clustering (SC).
+ diar.do_spec_clustering(
+ diary_obj,
+ out_rttm_file,
+ rec_id,
+ num_spkrs,
+ pval,
+ config.affinity,
+ n_neighbors, )
+
+        # can be used for AHC later. Likewise, one can add different backends here.
+ if config.backend == "AHC":
+ # call AHC
+ threshold = pval # pval for AHC is nothing but threshold.
+ diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)
+
+ # once all RTTM outputs are generated, concatenate individual RTTM files to obtain single RTTM file.
+    # this is not strictly needed, but it keeps the output consistent with the standard format.
+ concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
+ logger.debug("Concatenating individual RTTM files...")
+ with open(concate_rttm_file, "w") as cat_file:
+ for f in glob.glob(out_rttm_dir + "/*.rttm"):
+ if f == concate_rttm_file:
+ continue
+ with open(f, "r") as indi_rttm_file:
+ shutil.copyfileobj(indi_rttm_file, cat_file)
+
+ msg = "The system generated RTTM file for %s set : %s" % (
+ split_type, concate_rttm_file, )
+ logger.debug(msg)
+
+ return concate_rttm_file
+
+
+def dev_pval_tuner(full_meta, save_dir, config):
+ """Tuning p_value for affinity matrix.
+ The p_value used so that only p% of the values in each row is retained.
+ """
+
+ DER_list = []
+ prange = np.arange(0.002, 0.015, 0.001)
+
+ n_lambdas = None # using it as flag later.
+ for p_v in prange:
+ # Process whole dataset for value of p_v.
+ concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+ save_dir, config)
+
+ ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+ "fullref_ami_dev.rttm")
+ sys_rttm_file = concate_rttm_file
+ [MS, FA, SER, DER_] = DER(
+ ref_rttm_file,
+ sys_rttm_file,
+ config.ignore_overlap,
+ config.forgiveness_collar, )
+
+ DER_list.append(DER_)
+
+ if config.oracle_n_spkrs is True and config.backend == "kmeans":
+            # no need for a p_val search. Note: p_val is needed for SC for both oracle and estimated num of speakers.
+            # p_val is also needed when oracle_n_spkrs=False with the kmeans backend.
+ break
+
+    # Take the p_val that gave the minimum DER on the dev set.
+ tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+ return tuned_p_val
+
+
+def dev_ahc_threshold_tuner(full_meta, save_dir, config):
+ """Tuning threshold for affinity matrix. This function is called when AHC is used as backend.
+ """
+
+ DER_list = []
+ prange = np.arange(0.0, 1.0, 0.1)
+
+ n_lambdas = None # using it as flag later.
+
+ # Note: p_val is threshold in case of AHC.
+ for p_v in prange:
+ # Process whole dataset for value of p_v.
+ concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+ save_dir, config)
+
+ ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+ "fullref_ami_dev.rttm")
+ sys_rttm = concate_rttm_file
+ [MS, FA, SER, DER_] = DER(
+ ref_rttm,
+ sys_rttm,
+ config.ignore_overlap,
+ config.forgiveness_collar, )
+
+ DER_list.append(DER_)
+
+ if config.oracle_n_spkrs is True:
+ break # no need of threshold search.
+
+    # Take the p_val that gave the minimum DER on the dev set.
+ tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+ return tuned_p_val
+
+
+def dev_nn_tuner(full_meta, split_type, save_dir, config):
+ """Tuning n_neighbors on dev set. Assuming oracle num of speakers.
+ This is used when nn based affinity is selected.
+ """
+
+ DER_list = []
+ pval = None
+
+    # Now assuming oracle num of speakers.
+ n_lambdas = 4
+
+ for nn in range(5, 15):
+
+ # Process whole dataset for value of n_lambdas.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+ save_dir, config, nn)
+
+ ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+ "fullref_ami_dev.rttm")
+ sys_rttm = concate_rttm_file
+ [MS, FA, SER, DER_] = DER(
+ ref_rttm,
+ sys_rttm,
+ config.ignore_overlap,
+ config.forgiveness_collar, )
+
+ DER_list.append([nn, DER_])
+
+ if config.oracle_n_spkrs is True and config.backend == "kmeans":
+ break
+
+ DER_list.sort(key=lambda x: x[1])
+    tuned_nn = DER_list[0]
+
+    return tuned_nn[0]
+
+
+def dev_tuner(full_meta, split_type, save_dir, config):
+ """Tuning n_components on dev set. Used for nn based affinity matrix.
+    Note: This is a very basic tuning for nn based affinity.
+    This is work in progress until we find a better way.
+ """
+
+ DER_list = []
+ pval = None
+ for n_lambdas in range(1, config.max_num_spkrs + 1):
+
+ # Process whole dataset for value of n_lambdas.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+ save_dir, config)
+
+ ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+ "fullref_ami_dev.rttm")
+ sys_rttm = concate_rttm_file
+ [MS, FA, SER, DER_] = DER(
+ ref_rttm,
+ sys_rttm,
+ config.ignore_overlap,
+ config.forgiveness_collar, )
+
+ DER_list.append(DER_)
+
+    # Take the n_lambdas with minimum DER.
+ tuned_n_lambdas = DER_list.index(min(DER_list)) + 1
+
+ return tuned_n_lambdas
+
+
+def main(args, config):
+    # AMI Dev Set: Tune hyperparams on the dev set.
+    # Read the embedding file for the dev set generated during embedding computation
+ dev_meta_file = os.path.join(
+ args.data_dir,
+ config.meta_data_dir,
+ "ami_dev." + config.mic_type + ".subsegs.json", )
+ with open(dev_meta_file, "r") as f:
+ meta_dev = json.load(f)
+
+ full_meta = meta_dev
+
+ # Processing starts from here
+    # The following lines select options for the different backends and affinity matrices, and find the best hyperparameter values using the dev set.
+ ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
+ "fullref_ami_dev.rttm")
+ best_nn = None
+ if config.affinity == "nn":
+ logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
+        best_nn = dev_nn_tuner(full_meta, "dev", args.data_dir, config)
+
+ n_lambdas = None
+ best_pval = None
+
+ if config.affinity == "cos" and (config.backend == "SC" or
+ config.backend == "kmeans"):
+ # oracle num_spkrs or not, doesn't matter for kmeans and SC backends
+ # cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs)
+ logger.info(
+ "Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
+ best_pval = dev_pval_tuner(full_meta, args.data_dir, config)
+
+ elif config.backend == "AHC":
+ logger.info("Tuning for threshold-value for AHC")
+ best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
+ config)
+ best_pval = best_threshold
+ else:
+ # NN for unknown num of speakers (can be used in future)
+ if config.oracle_n_spkrs is False:
+ # nn: Tune num of number of components (to be updated later)
+ logger.info(
+ "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
+ )
+            # dev_tuner is used for tuning the num of components in NN. Can be used in the future.
+            n_lambdas = dev_tuner(full_meta, "dev", args.data_dir, config)
+
+ # load 'dev' and 'eval' metadata files.
+ full_meta_dev = full_meta # current full_meta is for 'dev'
+ eval_meta_file = os.path.join(
+ args.data_dir,
+ config.meta_data_dir,
+ "ami_eval." + config.mic_type + ".subsegs.json", )
+ with open(eval_meta_file, "r") as f:
+ full_meta_eval = json.load(f)
+
+ # tag to be appended to final output DER files. Writing DER for individual files.
+ type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+ tag = (
+ type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)
+
+ # perform final diarization on 'dev' and 'eval' with best hyperparams.
+ final_DERs = {}
+ out_der_dir = os.path.join(args.data_dir, config.der_dir)
+ if not os.path.exists(out_der_dir):
+ os.makedirs(out_der_dir)
+
+ for split_type in ["dev", "eval"]:
+ if split_type == "dev":
+ full_meta = full_meta_dev
+ else:
+ full_meta = full_meta_eval
+
+ # performing diarization.
+ msg = "Diarizing using best hyperparams: " + split_type + " set"
+ logger.info(msg)
+ out_boundaries = diarize_dataset(
+ full_meta,
+ split_type,
+ n_lambdas=n_lambdas,
+ pval=best_pval,
+ n_neighbors=best_nn,
+ save_dir=args.data_dir,
+ config=config)
+
+ # computing DER.
+ msg = "Computing DERs for " + split_type + " set"
+ logger.info(msg)
+ ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
+ "fullref_ami_" + split_type + ".rttm")
+ sys_rttm = out_boundaries
+ [MS, FA, SER, DER_vals] = DER(
+ ref_rttm,
+ sys_rttm,
+ config.ignore_overlap,
+ config.forgiveness_collar,
+ individual_file_scores=True, )
+
+ # writing DER values to a file. Append tag.
+ der_file_name = split_type + "_DER_" + tag
+ out_der_file = os.path.join(out_der_dir, der_file_name)
+ msg = "Writing DER file to: " + out_der_file
+ logger.info(msg)
+ diar.write_ders_file(ref_rttm, DER_vals, out_der_file)
+
+ msg = ("AMI " + split_type + " set DER = %s %%\n" %
+ (str(round(DER_vals[-1], 2))))
+ logger.info(msg)
+ final_DERs[split_type] = round(DER_vals[-1], 2)
+
+ # final print DERs
+ msg = (
+ "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n"
+ % (str(final_DERs["dev"]), str(final_DERs["eval"])))
+ logger.info(msg)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(__doc__)
+ parser.add_argument(
+ "--config", default=None, type=str, help="configuration file")
+ parser.add_argument(
+ "--data-dir",
+ default="../data/",
+ type=str,
+ help="processsed data directory")
+ args = parser.parse_args()
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+
+ config.freeze()
+
+ main(args, config)
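The hyperparameter tuners in `experiment.py` all follow the same shape: sweep a candidate value over a grid, diarize the dev set with it, score the result with DER, and keep the value with the smallest error. The sketch below isolates that pattern; `dummy_der` is a placeholder for the real `DER(ref_rttm, sys_rttm, ...)` call, so the numbers are illustrative only.

```python
# Sketch of the p-value grid search used by dev_pval_tuner, with DER replaced by a dummy scorer.
import numpy as np

def dummy_der(p_val: float) -> float:
    # stand-in for DER(ref_rttm_file, sys_rttm_file, ignore_overlap, forgiveness_collar);
    # here just a smooth curve with a minimum near p = 0.007
    return (p_val - 0.007) ** 2 * 1e4 + 3.0

prange = np.arange(0.002, 0.015, 0.001)          # same grid as dev_pval_tuner
der_list = [dummy_der(p_v) for p_v in prange]

tuned_p_val = prange[int(np.argmin(der_list))]   # keep the value with minimum DER
print("best p-value:", round(float(tuned_p_val), 3))
```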
diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh
new file mode 100755
index 000000000..1dfd11b86
--- /dev/null
+++ b/examples/ami/sd0/local/process.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+stage=0
+set=L
+
+. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+set -o pipefail
+
+data_folder=$1
+manual_annot_folder=$2
+save_folder=$3
+pretrained_model_dir=$4
+conf_path=$5
+device=$6
+
+ref_rttm_dir=${save_folder}/ref_rttms
+meta_data_dir=${save_folder}/metadata
+
+if [ ${stage} -le 0 ]; then
+ echo "AMI Data preparation"
+ python local/ami_prepare.py --data_folder ${data_folder} \
+ --manual_annot_folder ${manual_annot_folder} \
+ --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
+ --meta_data_dir ${meta_data_dir}
+
+ if [ $? -ne 0 ]; then
+ echo "Prepare AMI failed. Please check log message."
+ exit 1
+ fi
+ echo "AMI data preparation done."
+fi
+
+if [ ${stage} -le 1 ]; then
+    # extract embeddings for the dev and eval datasets
+ for name in dev eval; do
+ python local/compute_embdding.py --config ${conf_path} \
+ --data-dir ${save_folder} \
+ --device ${device} \
+ --dataset ${name} \
+ --load-checkpoint ${pretrained_model_dir}
+ done
+fi
+
+if [ ${stage} -le 2 ]; then
+ # tune hyperparams on dev set
+ # perform final diarization on 'dev' and 'eval' with best hyperparams
+ python local/experiment.py --config ${conf_path} \
+ --data-dir ${save_folder}
+fi
diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh
index 91d4b706a..9035f5955 100644
--- a/examples/ami/sd0/run.sh
+++ b/examples/ami/sd0/run.sh
@@ -1,14 +1,46 @@
#!/bin/bash
-. path.sh || exit 1;
+. ./path.sh || exit 1;
set -e
-stage=1
+stage=0
+#TARGET_DIR=${MAIN_ROOT}/dataset/ami
+TARGET_DIR=/home/dataset/AMI
+data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
+manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
+
+save_folder=./save
+pretrained_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
+conf_path=conf/ecapa_tdnn.yaml
+device=gpu
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
-if [ ${stage} -le 1 ]; then
- # prepare data
- bash ./local/data.sh || exit -1
-fi
\ No newline at end of file
+if [ $stage -le 0 ]; then
+ # Prepare data
+    # Download the AMI corpus. You need around 10 GB of free space to get the whole data.
+ # The signals are too large to package in this way,
+ # so you need to use the chooser to indicate which ones you wish to download
+ echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
+ echo "Annotations: AMI manual annotations v1.6.2 "
+ echo "Signals: "
+ echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
+ echo "2) Select media streams: Just select Headset mix"
+fi
+
+if [ $stage -le 1 ]; then
+ # Download the pretrained model
+ wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+ mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
+ rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+ echo "download the pretrained ECAPA-TDNN Model to path: "${pretraind_model_dir}
+fi
+
+if [ $stage -le 2 ]; then
+ # Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams.
+    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretrained_model_dir} ${conf_path}
+ bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
+           ${save_folder} ${pretrained_model_dir} ${conf_path} ${device} || exit 1
+fi
+
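Stage 1 above just fetches and unpacks the pretrained ECAPA-TDNN checkpoint. For readers who prefer to script this step, the sketch below is a rough Python equivalent of the `wget`/`tar` commands (same URL as above; network access and write permission to `./save` are assumed).

```python
# Rough Python equivalent of stage 1 above: download and unpack the pretrained model.
import tarfile
import urllib.request
from pathlib import Path

url = ("https://paddlespeech.bj.bcebos.com/vector/voxceleb/"
       "sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz")
save_folder = Path("./save")
save_folder.mkdir(parents=True, exist_ok=True)

archive = save_folder / "sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz"
urllib.request.urlretrieve(url, archive)        # same file the wget command fetches
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(path=save_folder)            # same as: tar -xvf ... -C ${save_folder}
archive.unlink()                                # mirrors removing the tarball afterwards
print("pretrained model unpacked under", save_folder)
```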
diff --git a/examples/csmsc/tts0/README.md b/examples/csmsc/tts0/README.md
index 0129329ae..01376bd61 100644
--- a/examples/csmsc/tts0/README.md
+++ b/examples/csmsc/tts0/README.md
@@ -212,7 +212,8 @@ optional arguments:
Pretrained Tacotron2 model with no silence in the edge of audios:
- [tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)
-The static model can be downloaded here [tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip).
+The static model can be downloaded here:
+- [tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/mse_loss | eval/bce_loss| eval/attn_loss
diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md
index 5f31f7b36..4fbe34cbf 100644
--- a/examples/csmsc/tts2/README.md
+++ b/examples/csmsc/tts2/README.md
@@ -221,9 +221,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
```
## Pretrained Model
-Pretrained SpeedySpeech model with no silence in the edge of audios[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip).
+Pretrained SpeedySpeech model with no silence in the edge of audios:
+- [speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)
-The static model can be downloaded here [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip).
+The static model can be downloaded here:
+- [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip)
+- [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss
:-------------:| :------------:| :-----: | :-----: | :--------:|:--------:
diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md
index ae8f7af60..bc672f66f 100644
--- a/examples/csmsc/tts3/README.md
+++ b/examples/csmsc/tts3/README.md
@@ -232,6 +232,9 @@ The static model can be downloaded here:
- [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip)
- [fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)
+The ONNX model can be downloaded here:
+- [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip)
+
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287|
diff --git a/examples/csmsc/tts3/local/ort_predict.sh b/examples/csmsc/tts3/local/ort_predict.sh
new file mode 100755
index 000000000..3154f6e5a
--- /dev/null
+++ b/examples/csmsc/tts3/local/ort_predict.sh
@@ -0,0 +1,31 @@
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# currently, only default fastspeech2 + hifigan/mb_melgan are supported!
+
+# synthesize from metadata
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../ort_predict.py \
+ --inference_dir=${train_output_path}/inference_onnx \
+ --am=fastspeech2_csmsc \
+ --voc=hifigan_csmsc \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/onnx_infer_out \
+ --device=cpu \
+ --cpu_threads=2
+fi
+
+# e2e, synthesize from text
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../ort_predict_e2e.py \
+ --inference_dir=${train_output_path}/inference_onnx \
+ --am=fastspeech2_csmsc \
+ --voc=hifigan_csmsc \
+ --output_dir=${train_output_path}/onnx_infer_out_e2e \
+ --text=${BIN_DIR}/../csmsc_test.txt \
+ --phones_dict=dump/phone_id_map.txt \
+ --device=cpu \
+ --cpu_threads=2
+fi
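On the Python side, `ort_predict.py` ultimately comes down to creating a CPU `onnxruntime.InferenceSession` per exported model and running the inputs through it. The sketch below shows only that session setup, mirroring `--device=cpu` and `--cpu_threads=2` from the script above; the model path is a placeholder, and the input metadata is read from the session rather than hard-coded, since the exact I/O names of the exported models are not given here.

```python
# Minimal onnxruntime CPU session sketch (model path is a placeholder).
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 2              # mirrors --cpu_threads=2

sess = ort.InferenceSession(
    "inference_onnx/fastspeech2_csmsc.onnx",       # produced by the paddle2onnx step
    sess_options=sess_options,
    providers=["CPUExecutionProvider"])            # mirrors --device=cpu

input_meta = sess.get_inputs()[0]                  # avoid hard-coding the input name
print("input:", input_meta.name, input_meta.shape, input_meta.type)
# a real call would then look like: outputs = sess.run(None, {input_meta.name: phone_id_array})
```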
diff --git a/examples/csmsc/tts3/local/paddle2onnx.sh b/examples/csmsc/tts3/local/paddle2onnx.sh
new file mode 100755
index 000000000..505f3b663
--- /dev/null
+++ b/examples/csmsc/tts3/local/paddle2onnx.sh
@@ -0,0 +1,22 @@
+train_output_path=$1
+model_dir=$2
+output_dir=$3
+model=$4
+
+enable_dev_version=True
+
+model_name=${model%_*}
+echo model_name: ${model_name}
+
+if [ ${model_name} = 'mb_melgan' ] ;then
+ enable_dev_version=False
+fi
+
+mkdir -p ${train_output_path}/${output_dir}
+
+paddle2onnx \
+ --model_dir ${train_output_path}/${model_dir} \
+ --model_filename ${model}.pdmodel \
+ --params_filename ${model}.pdiparams \
+ --save_file ${train_output_path}/${output_dir}/${model}.onnx \
+ --enable_dev_version ${enable_dev_version}
\ No newline at end of file
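Since `paddle2onnx` is driven entirely by command-line flags here, the same conversion can also be looped over several models from Python via `subprocess`. The sketch below only reassembles the flags already used in the script above; the `export_to_onnx` helper and the `exp/default` path are illustrative, not part of the repo.

```python
# Sketch: invoking the paddle2onnx CLI from Python with the same flags as the script above.
import subprocess
from pathlib import Path

def export_to_onnx(train_output_path: str, model_dir: str, output_dir: str,
                   model: str, enable_dev_version: bool = True) -> None:
    out_dir = Path(train_output_path) / output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    cmd = [
        "paddle2onnx",
        "--model_dir", str(Path(train_output_path) / model_dir),
        "--model_filename", f"{model}.pdmodel",
        "--params_filename", f"{model}.pdiparams",
        "--save_file", str(out_dir / f"{model}.onnx"),
        "--enable_dev_version", str(enable_dev_version),
    ]
    subprocess.run(cmd, check=True)

# mb_melgan is the one model exported with enable_dev_version=False in the script above
export_to_onnx("exp/default", "inference", "inference_onnx", "fastspeech2_csmsc")
export_to_onnx("exp/default", "inference", "inference_onnx", "mb_melgan_csmsc",
               enable_dev_version=False)
```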
diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh
index e1a149b65..b617d5352 100755
--- a/examples/csmsc/tts3/run.sh
+++ b/examples/csmsc/tts3/run.sh
@@ -41,3 +41,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi
+# paddle2onnx, please make sure the static models are in ${train_output_path}/inference first
+# we have only tested the following models so far
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ # install paddle2onnx
+ version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}')
+ if [[ -z "$version" || ${version} != '0.9.4' ]]; then
+ pip install paddle2onnx==0.9.4
+ fi
+ ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx fastspeech2_csmsc
+ ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc
+ ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx mb_melgan_csmsc
+fi
+
+# inference with onnxruntime, use fastspeech2 + hifigan by default
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ # install onnxruntime
+ version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}')
+ if [[ -z "$version" || ${version} != '1.10.0' ]]; then
+ pip install onnxruntime==1.10.0
+ fi
+ ./local/ort_predict.sh ${train_output_path}
+fi
diff --git a/examples/csmsc/voc1/README.md b/examples/csmsc/voc1/README.md
index 5527e8088..2d6de168a 100644
--- a/examples/csmsc/voc1/README.md
+++ b/examples/csmsc/voc1/README.md
@@ -127,9 +127,11 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-The pretrained model can be downloaded here [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip).
+The pretrained model can be downloaded here:
+- [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)
-The static model can be downloaded here [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip).
+The static model can be downloaded here:
+- [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss| eval/spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md
index 22104a8f2..12adaf7f4 100644
--- a/examples/csmsc/voc3/README.md
+++ b/examples/csmsc/voc3/README.md
@@ -152,11 +152,17 @@ TODO:
The hyperparameter of `finetune.yaml` is not good enough, a smaller `learning_rate` should be used (more `milestones` should be set).
## Pretrained Models
-The pretrained model can be downloaded here [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip).
+The pretrained model can be downloaded here:
+- [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip)
-The finetuned model can be downloaded here [mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip).
+The finetuned model can be downloaded here:
+- [mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)
-The static model can be downloaded here [mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip)
+The static model can be downloaded here:
+- [mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip)
+
+The ONNX model can be downloaded here:
+- [mb_melgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip)
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss|eval/spectral_convergence_loss |eval/sub_log_stft_magnitude_loss|eval/sub_spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:| :--------:| :--------:
diff --git a/examples/csmsc/voc4/README.md b/examples/csmsc/voc4/README.md
index b5c687391..b7add3e57 100644
--- a/examples/csmsc/voc4/README.md
+++ b/examples/csmsc/voc4/README.md
@@ -112,7 +112,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-The pretrained model can be downloaded here [style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip).
+The pretrained model can be downloaded here:
+- [style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip)
The static model of Style MelGAN is not available now.
diff --git a/examples/csmsc/voc5/README.md b/examples/csmsc/voc5/README.md
index 21afe6eef..33e676165 100644
--- a/examples/csmsc/voc5/README.md
+++ b/examples/csmsc/voc5/README.md
@@ -112,9 +112,14 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-The pretrained model can be downloaded here [hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip).
+The pretrained model can be downloaded here:
+- [hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)
-The static model can be downloaded here [hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip).
+The static model can be downloaded here:
+- [hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip)
+
+The ONNX model can be downloaded here:
+- [hifigan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip)
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
diff --git a/examples/csmsc/voc6/README.md b/examples/csmsc/voc6/README.md
index 7763b3551..26d4523d9 100644
--- a/examples/csmsc/voc6/README.md
+++ b/examples/csmsc/voc6/README.md
@@ -109,9 +109,11 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Models
-The pretrained model can be downloaded here [wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip).
+The pretrained model can be downloaded here:
+- [wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip)
-The static model can be downloaded here [wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip).
+The static model can be downloaded here:
+- [wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip)
Model | Step | eval/loss
:-------------:|:------------:| :------------:
diff --git a/examples/iwslt2012/punc0/README.md b/examples/iwslt2012/punc0/README.md
index 74d599a21..6caa9710b 100644
--- a/examples/iwslt2012/punc0/README.md
+++ b/examples/iwslt2012/punc0/README.md
@@ -21,7 +21,7 @@
The pretrained model can be downloaded here [ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip).
### Test Result
-- Ernie Linear
+- Ernie
| |COMMA | PERIOD | QUESTION | OVERALL|
|:-----:|:-----:|:-----:|:-----:|:-----:|
|Precision |0.510955 |0.526462 |0.820755 |0.619391|
diff --git a/examples/iwslt2012/punc0/RESULTS.md b/examples/iwslt2012/punc0/RESULTS.md
new file mode 100644
index 000000000..2e22713d8
--- /dev/null
+++ b/examples/iwslt2012/punc0/RESULTS.md
@@ -0,0 +1,9 @@
+# iwslt2012
+
+## Ernie
+
+| |COMMA | PERIOD | QUESTION | OVERALL|
+|:-----:|:-----:|:-----:|:-----:|:-----:|
+|Precision |0.510955 |0.526462 |0.820755 |0.619391|
+|Recall |0.517433 |0.564179 |0.861386 |0.647666|
+|F1 |0.514173 |0.544669 |0.840580 |0.633141|
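The F1 column here is simply the harmonic mean of the precision and recall columns, which is easy to verify. The small check below recomputes all three rows; the tiny differences versus the table come from the inputs already being rounded to six decimals.

```python
# Check that F1 = 2PR / (P + R) reproduces the table values.
def f1(precision: float, recall: float) -> float:
    return 2 * precision * recall / (precision + recall)

print(round(f1(0.510955, 0.517433), 6))  # COMMA row,    ~0.514173 in the table
print(round(f1(0.526462, 0.564179), 6))  # PERIOD row,   ~0.544669 in the table
print(round(f1(0.820755, 0.861386), 6))  # QUESTION row, ~0.840580 in the table
```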
diff --git a/examples/librispeech/asr1/README.md b/examples/librispeech/asr1/README.md
index eb1a44001..ae252a58b 100644
--- a/examples/librispeech/asr1/README.md
+++ b/examples/librispeech/asr1/README.md
@@ -151,44 +151,22 @@ avg.sh best exp/conformer/checkpoints 20
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20
```
## Pretrained Model
-You can get the pretrained transformer or conformer using the scripts below:
-```bash
-# Conformer:
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz
-# Transformer:
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/transformer.model.tar.gz
-```
+You can get the pretrained transformer or conformer model from the [released models](../../../docs/source/released_model.md) page.
+
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz
-tar xzvf transformer.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz
+tar xzvf asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz
source path.sh
# If you have process the data and get the manifest file, you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20
```
-The performance of the released models are shown below:
-## Conformer
-train: Epoch 70, 4 V100-32G, best avg: 20
-
-| Model | Params | Config | Augmentation | Test set | Decode method | Loss | WER |
-| --------- | ------- | ------------------- | ------------ | ---------- | ---------------------- | ----------------- | -------- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | attention | 6.433612394332886 | 0.039771 |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.433612394332886 | 0.040342 |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.433612394332886 | 0.040342 |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | attention_rescoring | 6.433612394332886 | 0.033761 |
-## Transformer
-train: Epoch 120, 4 V100-32G, 27 Day, best avg: 10
+The performance of the released models is shown in [RESULTS.md](./RESULTS.md).
-| Model | Params | Config | Augmentation | Test set | Decode method | Loss | WER |
-| ----------- | ------- | --------------------- | ------------ | ---------- | ---------------------- | ----------------- | -------- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.382194232940674 | 0.049661 |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.382194232940674 | 0.049566 |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.382194232940674 | 0.049585 |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.382194232940674 | 0.038135 |
## Stage 4: CTC Alignment
If you want to get the alignment between the audio and the text, you can use the ctc alignment. The code of this stage is shown below:
```bash
@@ -227,8 +205,8 @@ In some situations, you want to use the trained model to do the inference for th
```
you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz
-tar xzvf conformer.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz
+tar xzvf asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz
```
You can download the audio demo:
```bash
diff --git a/examples/librispeech/asr2/README.md b/examples/librispeech/asr2/README.md
index 7d6fe11df..5bc7185a9 100644
--- a/examples/librispeech/asr2/README.md
+++ b/examples/librispeech/asr2/README.md
@@ -1,4 +1,4 @@
-# Transformer/Conformer ASR with Librispeech Asr2
+# Transformer/Conformer ASR with Librispeech ASR2
This example contains code used to train a Transformer or [Conformer](http://arxiv.org/abs/2008.03802) model with [Librispeech dataset](http://www.openslr.org/resources/12) and use some functions in kaldi.
@@ -213,17 +213,14 @@ avg.sh latest exp/transformer/checkpoints 10
./local/recog.sh --ckpt_prefix exp/transformer/checkpoints/avg_10
```
## Pretrained Model
-You can get the pretrained transformer using the scripts below:
-```bash
-# Transformer:
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/transformer.model.tar.gz
-```
+You can get the pretrained models from the [released models](../../../docs/source/released_model.md) page.
+
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```bash
-wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/transformer.model.tar.gz
-tar xzvf transformer.model.tar.gz
+wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz
+tar xzvf asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz
source path.sh
# If you have process the data and get the manifest file, you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
@@ -231,26 +228,7 @@ bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer.yaml exp/ctc/checkpoints/avg_10
```
-The performance of the released models are shown below:
-### Transformer
-| Model | Params | GPUS | Averaged Model | Config | Augmentation | Loss |
-| :---------: | :----: | :--------------------: | :--------------: | :-------------------: | :----------: | :-------------: |
-| transformer | 32.52M | 8 Tesla V100-SXM2-32GB | 10-best val_loss | conf/transformer.yaml | spec_aug | 6.3197922706604 |
-
-#### Attention Rescore
-| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
-| ---------- | --------------------- | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ----- |
-| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
-| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
-| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
-| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
-
-#### JoinCTC
-| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
-| ---------- | ----------------- | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ----- |
-| test-clean | join_ctc_only_att | 2620 | 52576 | 96.1 | 2.5 | 1.4 | 0.4 | 4.4 | 34.7 |
-| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
-| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 |
+The performance of the released models is shown in [RESULTS.md](./RESULTS.md).
Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus) we using 8gpu, but the model size (aheads4-adim256) small than it.
## Stage 5: CTC Alignment
diff --git a/examples/ljspeech/tts1/README.md b/examples/ljspeech/tts1/README.md
index 4f7680e84..7f32522ac 100644
--- a/examples/ljspeech/tts1/README.md
+++ b/examples/ljspeech/tts1/README.md
@@ -171,7 +171,8 @@ optional arguments:
6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained Model can be downloaded here. [transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)
+Pretrained Model can be downloaded here:
+- [transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)
TransformerTTS checkpoint contains files listed below.
```text
diff --git a/examples/ljspeech/tts3/README.md b/examples/ljspeech/tts3/README.md
index f5e919c0f..e028fa05d 100644
--- a/examples/ljspeech/tts3/README.md
+++ b/examples/ljspeech/tts3/README.md
@@ -214,7 +214,8 @@ optional arguments:
9. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained FastSpeech2 model with no silence in the edge of audios. [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)
+Pretrained FastSpeech2 model with no silence in the edge of audios:
+- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
diff --git a/examples/ljspeech/voc0/README.md b/examples/ljspeech/voc0/README.md
index 13a50efb5..41b08d57f 100644
--- a/examples/ljspeech/voc0/README.md
+++ b/examples/ljspeech/voc0/README.md
@@ -50,4 +50,5 @@ Synthesize waveform.
6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained Model with residual channel equals 128 can be downloaded here. [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip).
+Pretrained Model with residual channel equals 128 can be downloaded here:
+- [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)
diff --git a/examples/ljspeech/voc1/README.md b/examples/ljspeech/voc1/README.md
index 6fcb2a520..4513b2a05 100644
--- a/examples/ljspeech/voc1/README.md
+++ b/examples/ljspeech/voc1/README.md
@@ -127,7 +127,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained models can be downloaded here. [pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)
+Pretrained models can be downloaded here:
+- [pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)
Parallel WaveGAN checkpoint contains files listed below.
diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md
index 9fbb9f746..9b31e2650 100644
--- a/examples/ljspeech/voc5/README.md
+++ b/examples/ljspeech/voc5/README.md
@@ -127,7 +127,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-The pretrained model can be downloaded here [hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip).
+The pretrained model can be downloaded here:
+- [hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip)
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
@@ -143,6 +144,5 @@ hifigan_ljspeech_ckpt_0.2.0
└── snapshot_iter_2500000.pdz # generator parameters of hifigan
```
-
## Acknowledgement
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md
index 157949d1f..f373ca6a3 100644
--- a/examples/vctk/tts3/README.md
+++ b/examples/vctk/tts3/README.md
@@ -217,7 +217,8 @@ optional arguments:
9. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained FastSpeech2 model with no silence in the edge of audios. [fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip)
+Pretrained FastSpeech2 model with no silence in the edge of audios:
+- [fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip)
FastSpeech2 checkpoint contains files listed below.
```text
diff --git a/examples/vctk/voc1/README.md b/examples/vctk/voc1/README.md
index 4714f28dc..1c3016f88 100644
--- a/examples/vctk/voc1/README.md
+++ b/examples/vctk/voc1/README.md
@@ -132,7 +132,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-Pretrained models can be downloaded here [pwg_vctk_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip).
+Pretrained models can be downloaded here:
+- [pwg_vctk_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip)
Parallel WaveGAN checkpoint contains files listed below.
diff --git a/examples/vctk/voc5/README.md b/examples/vctk/voc5/README.md
index b4be341c0..4eb25c02d 100644
--- a/examples/vctk/voc5/README.md
+++ b/examples/vctk/voc5/README.md
@@ -133,7 +133,8 @@ optional arguments:
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
-The pretrained model can be downloaded here [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip).
+The pretrained model can be downloaded here:
+- [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip)
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
diff --git a/examples/voxceleb/sv0/RESULT.md b/examples/voxceleb/sv0/RESULT.md
index c37bcecef..3a3f67d09 100644
--- a/examples/voxceleb/sv0/RESULT.md
+++ b/examples/voxceleb/sv0/RESULT.md
@@ -4,4 +4,4 @@
| Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm |
| --- | --- | --- | --- | --- | --- | --- | ---- |
-| ECAPA-TDNN | 85M | 0.1.1 | conf/ecapa_tdnn.yaml |192 | test | 1.15 | 1.06 |
+| ECAPA-TDNN | 85M | 0.2.0 | conf/ecapa_tdnn.yaml |192 | test | 1.02 | 0.95 |
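The numbers in the last two columns are equal error rates (EER, in %): the operating point on the VoxCeleb1 trial list where the false-acceptance and false-rejection rates cross. The sketch below is a generic way to compute EER from similarity scores and trial labels on synthetic data; it is not the exact evaluation code behind this table.

```python
# Generic EER sketch: scores are similarity values, labels are 1 for same-speaker trials.
import numpy as np

def compute_eer(scores: np.ndarray, labels: np.ndarray) -> float:
    order = np.argsort(scores)                       # ascending: low scores rejected first
    scores, labels = scores[order], labels[order]
    n_pos, n_neg = labels.sum(), (1 - labels).sum()
    frr = np.cumsum(labels) / n_pos                  # positives at or below the threshold
    far = 1.0 - np.cumsum(1 - labels) / n_neg        # negatives above the threshold
    idx = int(np.argmin(np.abs(far - frr)))          # point where the two rates cross
    return float((far[idx] + frr[idx]) / 2)

rng = np.random.default_rng(0)
pos = rng.normal(0.6, 0.1, 1000)                     # synthetic same-speaker cosine scores
neg = rng.normal(0.3, 0.1, 1000)                     # synthetic different-speaker cosine scores
scores = np.concatenate([pos, neg])
labels = np.concatenate([np.ones(1000), np.zeros(1000)]).astype(int)
print(f"EER ~ {100 * compute_eer(scores, labels):.2f} %")
```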
diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
index e58dca82d..4715c5a3c 100644
--- a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
+++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
@@ -1,14 +1,16 @@
###########################################
# Data #
###########################################
-# we should explicitly specify the wav path of vox2 audio data converted from m4a
-vox2_base_path:
augment: True
-batch_size: 16
+batch_size: 32
num_workers: 2
-num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
+skip_prep: False
+split_ratio: 0.9
+chunk_duration: 3.0 # seconds
random_chunk: True
+verification_file: data/vox1/veri_test2.txt
###########################################################
# FEATURE EXTRACTION SETTING #
@@ -26,7 +28,6 @@ hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
# if we want use another model, please choose another configuration yaml file
model:
input_size: 80
- # "channels": [512, 512, 512, 512, 1536],
channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
@@ -38,8 +39,8 @@ model:
###########################################
seed: 1986 # according from speechbrain configuration
epochs: 10
-save_interval: 1
-log_interval: 1
+save_interval: 10
+log_interval: 10
learning_rate: 1e-8
diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml
new file mode 100644
index 000000000..5ad5ea285
--- /dev/null
+++ b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml
@@ -0,0 +1,53 @@
+###########################################
+# Data #
+###########################################
+augment: True
+batch_size: 16
+num_workers: 2
+num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+shuffle: True
+skip_prep: False
+split_ratio: 0.9
+chunk_duration: 3.0 # seconds
+random_chunk: True
+verification_file: data/vox1/veri_test2.txt
+
+###########################################################
+# FEATURE EXTRACTION SETTING #
+###########################################################
+# currently, we only support fbank
+sr: 16000 # sample rate
+n_mels: 80
+window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
+hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+# currently, only ECAPA-TDNN is supported in this configuration file
+# if you want to use another model, please choose another configuration YAML file
+model:
+ input_size: 80
+ channels: [512, 512, 512, 512, 1536]
+ kernel_sizes: [5, 3, 3, 3, 1]
+ dilations: [1, 2, 3, 4, 1]
+ attention_channels: 128
+ lin_neurons: 192
+
+###########################################
+# Training #
+###########################################
+seed: 1986 # following the speechbrain configuration
+epochs: 100
+save_interval: 10
+log_interval: 10
+learning_rate: 1e-8
+
+
+###########################################
+# Testing #
+###########################################
+global_embedding_norm: True
+embedding_mean_norm: True
+embedding_std_norm: False
+
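The `model:` block in this file is passed straight to the network constructor as keyword arguments, the same `EcapaTdnn(**config.model)` pattern used in `compute_embdding.py` above. The sketch below builds the smaller network from this YAML and counts its parameters; it assumes PaddleSpeech is installed and that the file is available as `conf/ecapa_tdnn_small.yaml`.

```python
# Sketch: build the smaller ECAPA-TDNN from this YAML (assumes paddlespeech is installed).
import numpy as np
from yacs.config import CfgNode

from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn

config = CfgNode(new_allowed=True)
config.merge_from_file("conf/ecapa_tdnn_small.yaml")
config.freeze()

model = EcapaTdnn(**config.model)            # channels [512, ...] instead of [1024, ...]
n_params = sum(int(np.prod(p.shape)) for p in model.parameters())
print(f"lin_neurons = {config.model.lin_neurons}, parameters ~ {n_params / 1e6:.1f} M")
```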
diff --git a/examples/voxceleb/sv0/local/data.sh b/examples/voxceleb/sv0/local/data.sh
index a3ff1c486..d6010ec66 100755
--- a/examples/voxceleb/sv0/local/data.sh
+++ b/examples/voxceleb/sv0/local/data.sh
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-stage=1
+stage=0
stop_stage=100
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
@@ -30,29 +30,114 @@ dir=$1
conf_path=$2
mkdir -p ${dir}
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
- # data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
- # we should use the local/convert.sh convert m4a to wav
- python3 local/data_prepare.py \
- --data-dir ${dir} \
- --config ${conf_path}
-fi
-
+# Generally, `MAIN_ROOT` refers to the root of PaddleSpeech,
+# which is defined in path.sh.
+# We will download the VoxCeleb data and RIRS noise to ${MAIN_ROOT}/dataset
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- # download data, generate manifests
- python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
- --manifest_prefix="data/vox1/manifest" \
+ # download data, generate manifests
+ # we will generate the manifest.{dev,test} file from ${TARGET_DIR}/voxceleb/vox1/{dev,test} directory
+ # and generate the meta info and download the trial file
+ # manifest.dev: 148642
+ # manifest.test: 4847
+ echo "Start to download vox1 dataset and generate the manifest files "
+ python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
+ --manifest_prefix="${dir}/vox1/manifest" \
--target_dir="${TARGET_DIR}/voxceleb/vox1/"
- if [ $? -ne 0 ]; then
- echo "Prepare voxceleb failed. Terminated."
- exit 1
- fi
+ if [ $? -ne 0 ]; then
+ echo "Prepare voxceleb1 failed. Terminated."
+ exit 1
+ fi
+
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # download voxceleb2 data
+ # we will download the data and unzip the package
+ # and we will store the m4a file in ${TARGET_DIR}/voxceleb/vox2/{dev,test}
+ echo "start to download vox2 dataset"
+ python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \
+ --download \
+ --target_dir="${TARGET_DIR}/voxceleb/vox2/"
+
+ if [ $? -ne 0 ]; then
+ echo "Download voxceleb2 dataset failed. Terminated."
+ exit 1
+ fi
+
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # convert the m4a to wav
+ # and we will not delete the original m4a file
+ echo "start to convert the m4a to wav"
+ bash local/convert.sh ${TARGET_DIR}/voxceleb/vox2/test/ || exit 1;
+
+ if [ $? -ne 0 ]; then
+ echo "Convert voxceleb2 dataset from m4a to wav failed. Terminated."
+ exit 1
+ fi
+ echo "m4a convert to wav operation finished"
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ # generate the vox2 manifest file from wav file
+ # we will generate the ${dir}/vox2/manifest.vox2
+    # because we use the entire vox2 dataset for training, we collect all the vox2 data in one file
+    echo "start to generate the vox2 manifest files"
+ python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \
+ --generate \
+ --manifest_prefix="${dir}/vox2/manifest" \
+ --target_dir="${TARGET_DIR}/voxceleb/vox2/"
- # for dataset in train dev test; do
- # mv data/manifest.${dataset} data/manifest.${dataset}.raw
- # done
-fi
\ No newline at end of file
+ if [ $? -ne 0 ]; then
+ echo "Prepare voxceleb2 dataset failed. Terminated."
+ exit 1
+ fi
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ # generate the vox csv file
+    # Currently, our training system uses csv files for the dataset
+ echo "convert the json format to csv format to be compatible with training process"
+ python3 local/make_vox_csv_dataset_from_json.py\
+ --train "${dir}/vox1/manifest.dev" "${dir}/vox2/manifest.vox2"\
+ --test "${dir}/vox1/manifest.test" \
+ --target_dir "${dir}/vox/" \
+ --config ${conf_path}
+
+ if [ $? -ne 0 ]; then
+ echo "Prepare voxceleb failed. Terminated."
+ exit 1
+ fi
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ # generate the open rir noise manifest file
+ echo "generate the open rir noise manifest file"
+ python3 ${TARGET_DIR}/rir_noise/rir_noise.py\
+ --manifest_prefix="${dir}/rir_noise/manifest" \
+ --target_dir="${TARGET_DIR}/rir_noise/"
+
+ if [ $? -ne 0 ]; then
+ echo "Prepare rir_noise failed. Terminated."
+ exit 1
+ fi
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+    # generate the open rir noise csv file
+ echo "generate the open rir noise csv file"
+ python3 local/make_rirs_noise_csv_dataset_from_json.py \
+ --noise_dir="${TARGET_DIR}/rir_noise/" \
+ --data_dir="${dir}/rir_noise/" \
+ --config ${conf_path}
+
+ if [ $? -ne 0 ]; then
+ echo "Prepare rir_noise failed. Terminated."
+ exit 1
+ fi
+fi
diff --git a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
new file mode 100644
index 000000000..b25a9d49a
--- /dev/null
+++ b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert the PaddleSpeech jsonline format data to csv format data for the voxceleb experiment.
+Currently, the speaker identification training process uses the csv format.
+"""
+import argparse
+import csv
+import os
+from typing import List
+
+import tqdm
+from yacs.config import CfgNode
+
+from paddleaudio import load as load_audio
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.utils.vector_utils import get_chunks
+
+logger = Log(__name__).getlog()
+
+
+def get_chunks_list(wav_file: str,
+ split_chunks: bool,
+ base_path: str,
+ chunk_duration: float=3.0) -> List[List[str]]:
+ """Get the single audio file info
+
+ Args:
+ wav_file (list): the wav audio file and get this audio segment info list
+ split_chunks (bool): audio split flag
+ base_path (str): the audio base path
+ chunk_duration (float): the chunk duration.
+ if set the split_chunks, we split the audio into multi-chunks segment.
+ """
+ waveform, sr = load_audio(wav_file)
+ audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0]
+ audio_duration = waveform.shape[0] / sr
+
+ ret = []
+    if split_chunks and audio_duration > chunk_duration:  # Split into pieces of chunk_duration seconds.
+ uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration)
+
+ for idx, chunk in enumerate(uniq_chunks_list):
+ s, e = chunk.split("_")[-2:] # Timestamps of start and end
+ start_sample = int(float(s) * sr)
+ end_sample = int(float(e) * sr)
+
+            # currently, all vector csv data uses one representation:
+            # id, duration, wav, start, stop, label
+            # in the rirs noise dataset, every label name is 'noise'
+            # the label is a string and we will convert it to an integer type in training
+ ret.append([
+ chunk, audio_duration, wav_file, start_sample, end_sample,
+ "noise"
+ ])
+ else: # Keep whole audio.
+ ret.append(
+ [audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"])
+ return ret
+
+
+def generate_csv(wav_files,
+ output_file: str,
+ base_path: str,
+ split_chunks: bool=True):
+ """Prepare the csv file according the wav files
+
+ Args:
+ wav_files (list): all the audio list to prepare the csv file
+ output_file (str): the output csv file
+ config (CfgNode): yaml configuration content
+ split_chunks (bool): audio split flag
+ """
+ logger.info(f'Generating csv: {output_file}')
+ header = ["utt_id", "duration", "wav", "start", "stop", "label"]
+ csv_lines = []
+ for item in tqdm.tqdm(wav_files):
+ csv_lines.extend(
+ get_chunks_list(
+ item, base_path=base_path, split_chunks=split_chunks))
+
+ if not os.path.exists(os.path.dirname(output_file)):
+ os.makedirs(os.path.dirname(output_file))
+
+ with open(output_file, mode="w") as csv_f:
+ csv_writer = csv.writer(
+ csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
+ csv_writer.writerow(header)
+ for line in csv_lines:
+ csv_writer.writerow(line)
+
+
+def prepare_data(args, config):
+ """Convert the jsonline format to csv format
+
+ Args:
+ args (argparse.Namespace): scripts args
+ config (CfgNode): yaml configuration content
+ """
+    # if the external config sets the skip_prep flag, we will do nothing
+ if config.skip_prep:
+ return
+
+ base_path = args.noise_dir
+ wav_path = os.path.join(base_path, "RIRS_NOISES")
+ logger.info(f"base path: {base_path}")
+ logger.info(f"wav path: {wav_path}")
+ rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list")
+ rir_files = []
+ with open(rir_list, 'r') as f:
+ for line in f.readlines():
+ rir_file = line.strip().split(' ')[-1]
+ rir_files.append(os.path.join(base_path, rir_file))
+
+ noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list")
+ noise_files = []
+ with open(noise_list, 'r') as f:
+ for line in f.readlines():
+ noise_file = line.strip().split(' ')[-1]
+ noise_files.append(os.path.join(base_path, noise_file))
+
+ csv_path = os.path.join(args.data_dir, 'csv')
+ logger.info(f"csv path: {csv_path}")
+ generate_csv(
+ rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
+ generate_csv(
+ noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--noise_dir",
+ default=None,
+ required=True,
+ help="The noise dataset dataset directory.")
+ parser.add_argument(
+ "--data_dir",
+ default=None,
+ required=True,
+ help="The target directory stores the csv files")
+ parser.add_argument(
+ "--config",
+ default=None,
+ required=True,
+ type=str,
+ help="configuration file")
+ args = parser.parse_args()
+
+ # parse the yaml config file
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+
+ # prepare the csv file from jsonlines files
+ prepare_data(args, config)
diff --git a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
new file mode 100644
index 000000000..4e64c3067
--- /dev/null
+++ b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert the PaddleSpeech jsonline format data to csv format data for the voxceleb experiment.
+Currently, the speaker identification training process uses the csv format.
+"""
+import argparse
+import csv
+import json
+import os
+import random
+
+import tqdm
+from yacs.config import CfgNode
+
+from paddleaudio import load as load_audio
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.utils.vector_utils import get_chunks
+
+logger = Log(__name__).getlog()
+
+
+def prepare_csv(wav_files, output_file, config, split_chunks=True):
+ """Prepare the csv file according the wav files
+
+ Args:
+ wav_files (list): all the audio list to prepare the csv file
+ output_file (str): the output csv file
+ config (CfgNode): yaml configuration content
+ split_chunks (bool, optional): audio split flag. Defaults to True.
+ """
+ if not os.path.exists(os.path.dirname(output_file)):
+ os.makedirs(os.path.dirname(output_file))
+ csv_lines = []
+ header = ["utt_id", "duration", "wav", "start", "stop", "label"]
+ # voxceleb meta info for each training utterance segment
+    # we extract a segment from an utterance for training
+    # and the segment's period lies between the start and stop sample points in the original wav file
+    # each field in the meta info means the following:
+    # utt_id: the utterance segment name, which is unique in the training dataset
+    # duration: the total utterance time
+    # wav: utterance file path, which should be an absolute path
+    # start: start point in the original wav file sample point range
+    # stop: stop point in the original wav file sample point range
+    # label: the utterance segment's label name,
+    #   which is the speaker name in the speaker verification domain
+ for item in tqdm.tqdm(wav_files, total=len(wav_files)):
+ item = json.loads(item.strip())
+ audio_id = item['utt'].replace(".wav",
+ "") # we remove the wav suffix name
+ audio_duration = item['feat_shape'][0]
+ wav_file = item['feat']
+ label = audio_id.split('-')[
+ 0] # speaker name in speaker verification domain
+ waveform, sr = load_audio(wav_file)
+ if split_chunks:
+ uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
+ audio_duration)
+ for chunk in uniq_chunks_list:
+ s, e = chunk.split("_")[-2:] # Timestamps of start and end
+ start_sample = int(float(s) * sr)
+ end_sample = int(float(e) * sr)
+ # id, duration, wav, start, stop, label
+                # in vector, the label is the speaker id
+ csv_lines.append([
+ chunk, audio_duration, wav_file, start_sample, end_sample,
+ label
+ ])
+ else:
+ csv_lines.append([
+ audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
+ ])
+
+ with open(output_file, mode="w") as csv_f:
+ csv_writer = csv.writer(
+ csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+ csv_writer.writerow(header)
+ for line in csv_lines:
+ csv_writer.writerow(line)
+
+
+def get_enroll_test_list(dataset_list, verification_file):
+ """Get the enroll and test utterance list from all the voxceleb1 test utterance dataset.
+       Generally, we get the enroll and test utterances from the verification file.
+       The verification file format is as follows:
+       target/nontarget enroll-utt test-utt,
+       where 0 means nontarget and 1 means target, e.g.:
+ 0 a.wav b.wav
+ 1 a.wav a.wav
+
+ Args:
+ dataset_list (list): all the dataset to get the test utterances
+ verification_file (str): voxceleb1 trial file
+ """
+ logger.info(f"verification file: {verification_file}")
+ enroll_audios = set()
+ test_audios = set()
+ with open(verification_file, 'r') as f:
+ for line in f:
+ _, enroll_file, test_file = line.strip().split(' ')
+ enroll_audios.add('-'.join(enroll_file.split('/')))
+ test_audios.add('-'.join(test_file.split('/')))
+
+ enroll_files = []
+ test_files = []
+ for dataset in dataset_list:
+ with open(dataset, 'r') as f:
+ for line in f:
+ # audio_id may be in enroll and test at the same time
+ # eg: 1 a.wav a.wav
+ # the audio a.wav is enroll and test file at the same time
+ audio_id = json.loads(line.strip())['utt']
+ if audio_id in enroll_audios:
+ enroll_files.append(line)
+ if audio_id in test_audios:
+ test_files.append(line)
+
+ enroll_files = sorted(enroll_files)
+ test_files = sorted(test_files)
+
+ return enroll_files, test_files
+
+
+def get_train_dev_list(dataset_list, target_dir, split_ratio):
+ """Get the train and dev utterance list from all the training utterance dataset.
+ Generally, we use the split_ratio as the train dataset ratio,
+       and the remaining utterances (ratio 1 - split_ratio) form the dev dataset
+
+ Args:
+ dataset_list (list): all the dataset to get the all utterances
+ target_dir (str): the target train and dev directory,
+ we will create the csv directory to store the {train,dev}.csv file
+ split_ratio (float): train dataset ratio in all utterance list
+ """
+ logger.info("start to get train and dev utt list")
+ if not os.path.exists(os.path.join(target_dir, "meta")):
+ os.makedirs(os.path.join(target_dir, "meta"))
+
+ audio_files = []
+ speakers = set()
+ for dataset in dataset_list:
+ with open(dataset, 'r') as f:
+ for line in f:
+ # the label is speaker name
+ label_name = json.loads(line.strip())['utt2spk']
+ speakers.add(label_name)
+ audio_files.append(line.strip())
+ speakers = sorted(speakers)
+ logger.info(f"we get {len(speakers)} speakers from all the train dataset")
+
+ with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
+ for label_id, label_name in enumerate(speakers):
+ f.write(f'{label_name} {label_id}\n')
+ logger.info(
+ f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
+ )
+
+ # the split_ratio is for train dataset
+ # the remaining is for dev dataset
+ split_idx = int(split_ratio * len(audio_files))
+ audio_files = sorted(audio_files)
+ random.shuffle(audio_files)
+ train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
+ logger.info(
+ f"we get train utterances: {len(train_files)}, dev utterance: {len(dev_files)}"
+ )
+ return train_files, dev_files
+
+
+def prepare_data(args, config):
+ """Convert the jsonline format to csv format
+
+ Args:
+ args (argparse.Namespace): scripts args
+ config (CfgNode): yaml configuration content
+ """
+ # stage0: set the random seed
+ random.seed(config.seed)
+
+    # if the external config sets the skip_prep flag, we will do nothing
+ if config.skip_prep:
+ return
+
+ # stage 1: prepare the enroll and test csv file
+ # And we generate the speaker to label file label2id.txt
+ logger.info("start to prepare the data csv file")
+ enroll_files, test_files = get_enroll_test_list(
+ [args.test], verification_file=config.verification_file)
+ prepare_csv(
+ enroll_files,
+ os.path.join(args.target_dir, "csv", "enroll.csv"),
+ config,
+ split_chunks=False)
+ prepare_csv(
+ test_files,
+ os.path.join(args.target_dir, "csv", "test.csv"),
+ config,
+ split_chunks=False)
+
+ # stage 2: prepare the train and dev csv file
+ # we get the train dataset ratio as config.split_ratio
+ # and the remaining is dev dataset
+ logger.info("start to prepare the data csv file")
+ train_files, dev_files = get_train_dev_list(
+ args.train, target_dir=args.target_dir, split_ratio=config.split_ratio)
+ prepare_csv(train_files,
+ os.path.join(args.target_dir, "csv", "train.csv"), config)
+ prepare_csv(dev_files,
+ os.path.join(args.target_dir, "csv", "dev.csv"), config)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--train",
+ required=True,
+ nargs='+',
+ help="The jsonline files list for train.")
+ parser.add_argument(
+ "--test", required=True, help="The jsonline file for test")
+ parser.add_argument(
+ "--target_dir",
+ default=None,
+ required=True,
+ help="The target directory stores the csv files and meta file.")
+ parser.add_argument(
+ "--config",
+ default=None,
+ required=True,
+ type=str,
+ help="configuration file")
+ args = parser.parse_args()
+
+ # parse the yaml config file
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+
+ # prepare the csv file from jsonlines files
+ prepare_data(args, config)
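To make the jsonline-to-csv mapping concrete, here is a toy manifest entry with the fields `prepare_csv()` actually reads; the utterance id, speaker and path values are illustrative only.

```python
import json

# Illustrative manifest line (values are made up); prepare_csv() reads the
# utt, feat, feat_shape and utt2spk fields from each jsonline entry.
line = json.dumps({
    "utt": "id10001-1zcIwhmdeo4-00001.wav",
    "utt2spk": "id10001",
    "feat": "/path/to/dataset/voxceleb/vox1/dev/wav/id10001/1zcIwhmdeo4/00001.wav",
    "feat_shape": [8.12],
})

item = json.loads(line)
audio_id = item["utt"].replace(".wav", "")   # id10001-1zcIwhmdeo4-00001
label = audio_id.split("-")[0]               # id10001, the speaker id
duration = item["feat_shape"][0]             # seconds
print(audio_id, duration, item["feat"], label)
```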
diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh
index bbc9e3dbb..e1dccf2ae 100755
--- a/examples/voxceleb/sv0/run.sh
+++ b/examples/voxceleb/sv0/run.sh
@@ -18,24 +18,22 @@ set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
-# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md with the script local/convert.sh
+# voxceleb2 data is in m4a format, so we need to convert the m4a to wav ourselves with the script local/convert.sh
# stage 1: train the speaker identification model
# stage 2: test speaker identification
-# stage 3: extract the training embeding to train the LDA and PLDA
+# stage 3: (todo) extract the training embedding to train the LDA and PLDA
######################################################################
-# we can set the variable PPAUDIO_HOME to specifiy the root directory of the downloaded vox1 and vox2 dataset
-# default the dataset will be stored in the ~/.paddleaudio/
# the vox2 dataset is stored in m4a format, we need to convert the audio from m4a to wav yourself
-# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
-# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
-# export PPAUDIO_HOME=
+# and put all of them into ${MAIN_ROOT}/dataset/voxceleb/vox2
+# we will find the wav from ${MAIN_ROOT}/dataset/voxceleb/vox1/{dev,test}/wav and ${MAIN_ROOT}/dataset/voxceleb/vox2/wav
+
stage=0
stop_stage=50
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
-# otherwise, we will store the wav info to vox1 and vox2 directory respectively
+# otherwise, we will store the wav info to data/vox1 and data/vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
dir=data/ # data info directory
@@ -64,6 +62,6 @@ if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
# if [ $stage -le 3 ]; then
-# # stage 2: extract the training embeding to train the LDA and PLDA
+# # stage 3: extract the training embeding to train the LDA and PLDA
# # todo: extract the training embedding
# fi
diff --git a/paddleaudio/paddleaudio/datasets/voxceleb.py b/paddleaudio/paddleaudio/datasets/voxceleb.py
index 3f72b5f2e..07f44e0c1 100644
--- a/paddleaudio/paddleaudio/datasets/voxceleb.py
+++ b/paddleaudio/paddleaudio/datasets/voxceleb.py
@@ -261,7 +261,7 @@ class VoxCeleb(Dataset):
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
- header = ["ID", "duration", "wav", "start", "stop", "spk_id"]
+ header = ["id", "duration", "wav", "start", "stop", "spk_id"]
# Note: this may occurs c++ execption, but the program will execute fine
# so we can ignore the execption
with Pool(cpu_count()) as p:
diff --git a/paddleaudio/paddleaudio/utils/numeric.py b/paddleaudio/paddleaudio/utils/numeric.py
new file mode 100644
index 000000000..126cada50
--- /dev/null
+++ b/paddleaudio/paddleaudio/utils/numeric.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+
+def pcm16to32(audio: np.ndarray) -> np.ndarray:
+ """pcm int16 to float32
+
+ Args:
+ audio (np.ndarray): Waveform with dtype of int16.
+
+ Returns:
+ np.ndarray: Waveform with dtype of float32.
+ """
+ if audio.dtype == np.int16:
+ audio = audio.astype("float32")
+ bits = np.iinfo(np.int16).bits
+ audio = audio / (2**(bits - 1))
+ return audio
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 1fb4be434..b12b9f6fc 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -80,9 +80,9 @@ pretrained_models = {
},
"deepspeech2online_aishell-zh-16k": {
'url':
- 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
+ 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
- 'd5e076217cf60486519f72c217d21b9b',
+ '23e16c69730a1cb5d735c98c83c21e16',
'cfg_path':
'model.yaml',
'ckpt_path':
@@ -426,6 +426,11 @@ class ASRExecutor(BaseExecutor):
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
+ audio_duration = audio.shape[0] / audio_sample_rate
+ max_duration = 50.0
+ if audio_duration >= max_duration:
+ logger.error("Please input audio file less then 50 seconds.\n")
+ return
except Exception as e:
logger.exception(e)
logger.error(
diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py
index 175a9723e..68e832ac7 100644
--- a/paddlespeech/cli/vector/infer.py
+++ b/paddlespeech/cli/vector/infer.py
@@ -15,6 +15,7 @@ import argparse
import os
import sys
from collections import OrderedDict
+from typing import Dict
from typing import List
from typing import Optional
from typing import Union
@@ -42,9 +43,9 @@ pretrained_models = {
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
- 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz',
+ 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
'md5':
- 'a1c0dba7d4de997187786ff517d5b4ec',
+ 'cc33023c54ab346cd318408f43fcaf95',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
@@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor):
"--task",
type=str,
default="spk",
- choices=["spk"],
+ choices=["spk", "score"],
help="task type in vector domain")
self.parser.add_argument(
"--input",
@@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor):
logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one
+            # we act according to the task type
task_result = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
- res = self(input_, model, sample_rate, config, ckpt_path,
- device)
- task_result[id_] = res
+ # extract the speaker audio embedding
+ if parser_args.task == "spk":
+ logger.info("do vector spk task")
+ res = self(input_, model, sample_rate, config, ckpt_path,
+ device)
+ task_result[id_] = res
+ elif parser_args.task == "score":
+ logger.info("do vector score task")
+ logger.info(f"input content {input_}")
+ if len(input_.split()) != 2:
+ logger.error(
+ f"vector score task input {input_} wav num is not two,"
+ "that is {len(input_.split())}")
+ sys.exit(-1)
+
+ # get the enroll and test embedding
+ enroll_audio, test_audio = input_.split()
+ logger.info(
+ f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
+ )
+ enroll_embedding = self(enroll_audio, model, sample_rate,
+ config, ckpt_path, device)
+ test_embedding = self(test_audio, model, sample_rate,
+ config, ckpt_path, device)
+
+ # get the score
+ res = self.get_embeddings_score(enroll_embedding,
+ test_embedding)
+ task_result[id_] = res
except Exception as e:
has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}'
@@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor):
else:
return True
+ def _get_job_contents(
+ self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
+ """
+ Read a job input file and return its contents in a dictionary.
+ Refactor from the Executor._get_job_contents
+
+ Args:
+ job_input (os.PathLike): The job input file.
+
+ Returns:
+ Dict[str, str]: Contents of job input.
+ """
+ job_contents = OrderedDict()
+ with open(job_input) as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ k = line.split(' ')[0]
+ v = ' '.join(line.split(' ')[1:])
+ job_contents[k] = v
+ return job_contents
+
+ def get_embeddings_score(self, enroll_embedding, test_embedding):
+ """get the enroll embedding and test embedding score
+
+ Args:
+ enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding
+ test_embedding (numpy.array): shape: (emb_size), test audio embedding
+
+ Returns:
+ score: the score between enroll embedding and test embedding
+ """
+ if not hasattr(self, "score_func"):
+ self.score_func = paddle.nn.CosineSimilarity(axis=0)
+ logger.info("create the cosine score function ")
+
+ score = self.score_func(
+ paddle.to_tensor(enroll_embedding),
+ paddle.to_tensor(test_embedding))
+
+ return score.item()
+
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
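The new score task boils down to a cosine similarity between two speaker embeddings. A standalone sketch of that scoring step follows; the 192-dimensional toy vectors are placeholders matching lin_neurons in the ecapa-tdnn config, not real model outputs.

```python
import numpy as np
import paddle

# Toy embeddings standing in for two ECAPA-TDNN outputs (192 dims = lin_neurons).
enroll_embedding = paddle.to_tensor(np.random.randn(192).astype("float32"))
test_embedding = paddle.to_tensor(np.random.randn(192).astype("float32"))

score_func = paddle.nn.CosineSimilarity(axis=0)
score = score_func(enroll_embedding, test_embedding).item()
print(f"cosine score: {score:.4f}")  # in [-1, 1]; higher suggests the same speaker
```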
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index f6a7f4295..474a8b79f 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -23,8 +23,9 @@ from ..util import cli_server_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import init_engine_pool
-from paddlespeech.server.restful.api import setup_router
+from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.ws.api import setup_router as setup_ws_router
__all__ = ['ServerExecutor', 'ServerStatsExecutor']
@@ -63,7 +64,12 @@ class ServerExecutor(BaseExecutor):
"""
# init api
api_list = list(engine.split("_")[0] for engine in config.engine_list)
- api_router = setup_router(api_list)
+ if config.protocol == "websocket":
+ api_router = setup_ws_router(api_list)
+ elif config.protocol == "http":
+ api_router = setup_http_router(api_list)
+ else:
+ raise Exception("unsupported protocol")
app.include_router(api_router)
if not init_engine_pool(config):
diff --git a/paddlespeech/server/conf/tts_online_application.yaml b/paddlespeech/server/conf/tts_online_application.yaml
new file mode 100644
index 000000000..a80b3ecec
--- /dev/null
+++ b/paddlespeech/server/conf/tts_online_application.yaml
@@ -0,0 +1,46 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+# SERVER SETTING #
+#################################################################################
+host: 127.0.0.1
+port: 8092
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_online', 'tts_online']
+# protocol = ['websocket', 'http'] (only one can be selected).
+protocol: 'http'
+engine_list: ['tts_online']
+
+
+#################################################################################
+# ENGINE CONFIG #
+#################################################################################
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: online #######################
+tts_online:
+ # am (acoustic model) choices=['fastspeech2_csmsc']
+ am: 'fastspeech2_csmsc'
+ am_config:
+ am_ckpt:
+ am_stat:
+ phones_dict:
+ tones_dict:
+ speaker_dict:
+ spk_id: 0
+
+ # voc (vocoder) choices=['mb_melgan_csmsc']
+ voc: 'mb_melgan_csmsc'
+ voc_config:
+ voc_ckpt:
+ voc_stat:
+
+ # others
+ lang: 'zh'
+ device: # set 'gpu:id' or 'cpu'
+ am_block: 42
+ am_pad: 12
+ voc_block: 14
+ voc_pad: 14
+
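A hedged sketch of a matching client for this config (host and port taken from the file above; the chunked response handling mirrors the full test client added later in this diff under paddlespeech/server/tests/tts/online/http_client.py):

```python
import base64
import json
import requests

# Minimal streaming client sketch for the config above (protocol 'http', port 8092).
url = "http://127.0.0.1:8092/paddlespeech/streaming/tts"
params = {"text": "您好,欢迎使用语音合成服务。", "spk_id": 0, "speed": 1.0,
          "volume": 1.0, "sample_rate": 0, "save_path": ""}

pcm = b""
resp = requests.post(url, json.dumps(params), stream=True)
for chunk in resp.iter_content(chunk_size=1024):
    pcm += base64.b64decode(chunk)  # each chunk is base64-encoded int16 PCM

print(f"received {len(pcm) / 2 / 24000:.2f} s of audio")  # mb_melgan_csmsc runs at 24 kHz
```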
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 389175a0a..1f356a3c6 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['ASREngine']
@@ -36,7 +37,7 @@ pretrained_models = {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
- 'd5e076217cf60486519f72c217d21b9b',
+ '23e16c69730a1cb5d735c98c83c21e16',
'cfg_path':
'model.yaml',
'ckpt_path':
@@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor):
else:
raise Exception("invalid model name")
- def _pcm16to32(self, audio):
- """pcm int16 to float32
-
- Args:
- audio(numpy.array): numpy.int16
-
- Returns:
- audio(numpy.array): numpy.float32
- """
- if audio.dtype == np.int16:
- audio = audio.astype("float32")
- bits = np.iinfo(np.int16).bits
- audio = audio / (2**(bits - 1))
- return audio
-
def extract_feat(self, samples, sample_rate):
"""extract feat
@@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens (numpy.array): shape[B]
"""
# pcm16 -> pcm 32
- samples = self._pcm16to32(samples)
+ samples = pcm2float(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index 2a39fb79b..e147a29a6 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -34,6 +34,9 @@ class EngineFactory(object):
elif engine_name == 'tts' and engine_type == 'python':
from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
return TTSEngine()
+ elif engine_name == 'tts' and engine_type == 'online':
+ from paddlespeech.server.engine.tts.online.tts_engine import TTSEngine
+ return TTSEngine()
elif engine_name == 'cls' and engine_type == 'inference':
from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
return CLSEngine()
diff --git a/paddlespeech/server/engine/tts/online/__init__.py b/paddlespeech/server/engine/tts/online/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/tts/online/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py
new file mode 100644
index 000000000..25a8bc76f
--- /dev/null
+++ b/paddlespeech/server/engine/tts/online/tts_engine.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import time
+
+import numpy as np
+import paddle
+
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.tts.infer import TTSExecutor
+from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.audio_process import float2pcm
+from paddlespeech.server.utils.util import get_chunks
+
+__all__ = ['TTSEngine']
+
+
+class TTSServerExecutor(TTSExecutor):
+ def __init__(self):
+ super().__init__()
+ pass
+
+ @paddle.no_grad()
+ def infer(
+ self,
+ text: str,
+ lang: str='zh',
+ am: str='fastspeech2_csmsc',
+ spk_id: int=0,
+ am_block: int=42,
+ am_pad: int=12,
+ voc_block: int=14,
+ voc_pad: int=14, ):
+ """
+        Model inference; the synthesized audio is yielded chunk by chunk.
+ """
+ am_name = am[:am.rindex('_')]
+ am_dataset = am[am.rindex('_') + 1:]
+ get_tone_ids = False
+ merge_sentences = False
+ frontend_st = time.time()
+ if lang == 'zh':
+ input_ids = self.frontend.get_input_ids(
+ text,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids)
+ phone_ids = input_ids["phone_ids"]
+ if get_tone_ids:
+ tone_ids = input_ids["tone_ids"]
+ elif lang == 'en':
+ input_ids = self.frontend.get_input_ids(
+ text, merge_sentences=merge_sentences)
+ phone_ids = input_ids["phone_ids"]
+ else:
+ print("lang should in {'zh', 'en'}!")
+ self.frontend_time = time.time() - frontend_st
+
+ for i in range(len(phone_ids)):
+ am_st = time.time()
+ part_phone_ids = phone_ids[i]
+ # am
+ if am_name == 'speedyspeech':
+ part_tone_ids = tone_ids[i]
+ mel = self.am_inference(part_phone_ids, part_tone_ids)
+ # fastspeech2
+ else:
+ # multi speaker
+ if am_dataset in {"aishell3", "vctk"}:
+ mel = self.am_inference(
+ part_phone_ids, spk_id=paddle.to_tensor(spk_id))
+ else:
+ mel = self.am_inference(part_phone_ids)
+ am_et = time.time()
+
+ # voc streaming
+ voc_upsample = self.voc_config.n_shift
+ mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
+ chunk_num = len(mel_chunks)
+ voc_st = time.time()
+ for i, mel_chunk in enumerate(mel_chunks):
+ sub_wav = self.voc_inference(mel_chunk)
+ front_pad = min(i * voc_block, voc_pad)
+
+ if i == 0:
+ sub_wav = sub_wav[:voc_block * voc_upsample]
+ elif i == chunk_num - 1:
+ sub_wav = sub_wav[front_pad * voc_upsample:]
+ else:
+ sub_wav = sub_wav[front_pad * voc_upsample:(
+ front_pad + voc_block) * voc_upsample]
+
+ yield sub_wav
+
+
+class TTSEngine(BaseEngine):
+ """TTS server engine
+
+ Args:
+ metaclass: Defaults to Singleton.
+ """
+
+ def __init__(self, name=None):
+ """Initialize TTS server engine
+ """
+ super(TTSEngine, self).__init__()
+
+ def init(self, config: dict) -> bool:
+ self.executor = TTSServerExecutor()
+ self.config = config
+ assert "fastspeech2_csmsc" in config.am and (
+ config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
+ ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
+ try:
+ if self.config.device:
+ self.device = self.config.device
+ else:
+ self.device = paddle.get_device()
+ paddle.set_device(self.device)
+ except Exception as e:
+ logger.error(
+ "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+ )
+ logger.error("Initialize TTS server engine Failed on device: %s." %
+ (self.device))
+ return False
+
+ try:
+ self.executor._init_from_path(
+ am=self.config.am,
+ am_config=self.config.am_config,
+ am_ckpt=self.config.am_ckpt,
+ am_stat=self.config.am_stat,
+ phones_dict=self.config.phones_dict,
+ tones_dict=self.config.tones_dict,
+ speaker_dict=self.config.speaker_dict,
+ voc=self.config.voc,
+ voc_config=self.config.voc_config,
+ voc_ckpt=self.config.voc_ckpt,
+ voc_stat=self.config.voc_stat,
+ lang=self.config.lang)
+ except Exception as e:
+ logger.error("Failed to get model related files.")
+ logger.error("Initialize TTS server engine Failed on device: %s." %
+ (self.device))
+ return False
+
+ self.am_block = self.config.am_block
+ self.am_pad = self.config.am_pad
+ self.voc_block = self.config.voc_block
+ self.voc_pad = self.config.voc_pad
+
+ logger.info("Initialize TTS server engine successfully on device: %s." %
+ (self.device))
+ return True
+
+ def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
+ # Convert byte to text
+ if text_bese64:
+ text_bytes = base64.b64decode(text_bese64) # base64 to bytes
+ text = text_bytes.decode('utf-8') # bytes to text
+
+ return text
+
+ def run(self,
+ sentence: str,
+ spk_id: int=0,
+ speed: float=1.0,
+ volume: float=1.0,
+ sample_rate: int=0,
+ save_path: str=None):
+ """ run include inference and postprocess.
+
+ Args:
+ sentence (str): text to be synthesized
+ spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
+ speed (float, optional): speed. Defaults to 1.0.
+ volume (float, optional): volume. Defaults to 1.0.
+ sample_rate (int, optional): target sample rate for synthesized audio,
+ 0 means the same as the model sampling rate. Defaults to 0.
+ save_path (str, optional): The save path of the synthesized audio.
+ None means do not save audio. Defaults to None.
+
+ Returns:
+ wav_base64: The base64 format of the synthesized audio.
+ """
+
+ lang = self.config.lang
+ wav_list = []
+
+ for wav in self.executor.infer(
+ text=sentence,
+ lang=lang,
+ am=self.config.am,
+ spk_id=spk_id,
+ am_block=self.am_block,
+ am_pad=self.am_pad,
+ voc_block=self.voc_block,
+ voc_pad=self.voc_pad):
+ # wav type: float32, convert to pcm (base64)
+ wav = float2pcm(wav) # float32 to int16
+ wav_bytes = wav.tobytes() # to bytes
+ wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64
+ wav_list.append(wav)
+
+ yield wav_base64
+
+ wav_all = np.concatenate(wav_list, axis=0)
+ logger.info("The durations of audio is: {} s".format(
+ len(wav_all) / self.executor.am_config.fs))
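A small numpy sketch of the chunk/trim bookkeeping in infer() above: each mel chunk carries pad frames on both sides, and the trimmed sub-wavs tile the full utterance exactly. The upsample factor of 300 stands in for voc_config.n_shift and is assumed for illustration only.

```python
import numpy as np

from paddlespeech.server.utils.util import get_chunks

voc_block, voc_pad, upsample = 14, 14, 300   # upsample stands in for voc_config.n_shift
mel = np.random.randn(100, 80)               # (frames, mel bins), toy values

chunks = get_chunks(mel, voc_block, voc_pad, "voc")
total = 0
for i, chunk in enumerate(chunks):
    sub_wav = np.zeros(chunk.shape[0] * upsample)   # stand-in for voc_inference output
    front_pad = min(i * voc_block, voc_pad)
    if i == 0:
        sub_wav = sub_wav[:voc_block * upsample]
    elif i == len(chunks) - 1:
        sub_wav = sub_wav[front_pad * upsample:]
    else:
        sub_wav = sub_wav[front_pad * upsample:(front_pad + voc_block) * upsample]
    total += len(sub_wav)

assert total == mel.shape[0] * upsample      # trimmed chunks cover the whole mel
```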
diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py
index 4e9bbe23e..d1268428a 100644
--- a/paddlespeech/server/restful/tts_api.py
+++ b/paddlespeech/server/restful/tts_api.py
@@ -15,6 +15,7 @@ import traceback
from typing import Union
from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
@@ -125,3 +126,14 @@ def tts(request_body: TTSRequest):
traceback.print_exc()
return response
+
+
+@router.post("/paddlespeech/streaming/tts")
+async def stream_tts(request_body: TTSRequest):
+ text = request_body.text
+
+ engine_pool = get_engine_pool()
+ tts_engine = engine_pool['tts']
+ logger.info("Get tts engine successfully.")
+
+ return StreamingResponse(tts_engine.run(sentence=text))
diff --git a/paddlespeech/server/tests/tts/test_client.py b/paddlespeech/server/tests/tts/offline/http_client.py
similarity index 90%
rename from paddlespeech/server/tests/tts/test_client.py
rename to paddlespeech/server/tests/tts/offline/http_client.py
index e42c9bcfa..1bdee4c18 100644
--- a/paddlespeech/server/tests/tts/test_client.py
+++ b/paddlespeech/server/tests/tts/offline/http_client.py
@@ -33,7 +33,8 @@ def tts_client(args):
text: A sentence to be synthesized
outfile: Synthetic audio file
"""
- url = 'http://127.0.0.1:8090/paddlespeech/tts'
+ url = "http://" + str(args.server) + ":" + str(
+ args.port) + "/paddlespeech/tts"
request = {
"text": args.text,
"spk_id": args.spk_id,
@@ -72,7 +73,7 @@ if __name__ == "__main__":
parser.add_argument(
'--text',
type=str,
- default="你好,欢迎使用语音合成服务",
+ default="您好,欢迎使用语音合成服务。",
help='A sentence to be synthesized')
parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
@@ -88,6 +89,9 @@ if __name__ == "__main__":
type=str,
default="./out.wav",
help='Synthesized audio file')
+ parser.add_argument(
+ "--server", type=str, help="server ip", default="127.0.0.1")
+ parser.add_argument("--port", type=int, help="server port", default=8090)
args = parser.parse_args()
st = time.time()
diff --git a/paddlespeech/server/tests/tts/online/http_client.py b/paddlespeech/server/tests/tts/online/http_client.py
new file mode 100644
index 000000000..cbc1f5c02
--- /dev/null
+++ b/paddlespeech/server/tests/tts/online/http_client.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import base64
+import json
+import os
+import time
+
+import requests
+
+from paddlespeech.server.utils.audio_process import pcm2wav
+
+
+def save_audio(buffer, audio_path) -> bool:
+ if args.save_path.endswith("pcm"):
+ with open(args.save_path, "wb") as f:
+ f.write(buffer)
+ elif args.save_path.endswith("wav"):
+ with open("./tmp.pcm", "wb") as f:
+ f.write(buffer)
+ pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000)
+ os.system("rm ./tmp.pcm")
+ else:
+ print("Only supports saved audio format is pcm or wav")
+ return False
+
+ return True
+
+
+def test(args):
+ params = {
+ "text": args.text,
+ "spk_id": args.spk_id,
+ "speed": args.speed,
+ "volume": args.volume,
+ "sample_rate": args.sample_rate,
+ "save_path": ''
+ }
+
+ buffer = b''
+ flag = 1
+ url = "http://" + str(args.server) + ":" + str(
+ args.port) + "/paddlespeech/streaming/tts"
+ st = time.time()
+ html = requests.post(url, json.dumps(params), stream=True)
+ for chunk in html.iter_content(chunk_size=1024):
+ chunk = base64.b64decode(chunk) # bytes
+ if flag:
+ first_response = time.time() - st
+ print(f"首包响应:{first_response} s")
+ flag = 0
+ buffer += chunk
+
+ final_response = time.time() - st
+ duration = len(buffer) / 2.0 / 24000
+
+ print(f"尾包响应:{final_response} s")
+ print(f"音频时长:{duration} s")
+ print(f"RTF: {final_response / duration}")
+
+ if args.save_path is not None:
+ if save_audio(buffer, args.save_path):
+ print("音频保存至:", args.save_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--text',
+ type=str,
+ default="您好,欢迎使用语音合成服务。",
+ help='A sentence to be synthesized')
+ parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
+ parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
+ parser.add_argument(
+ '--volume', type=float, default=1.0, help='Audio volume')
+ parser.add_argument(
+ '--sample_rate',
+ type=int,
+ default=0,
+ help='Sampling rate, the default is the same as the model')
+ parser.add_argument(
+ "--server", type=str, help="server ip", default="127.0.0.1")
+ parser.add_argument("--port", type=int, help="server port", default=8092)
+ parser.add_argument(
+ "--save_path", type=str, help="save audio path", default=None)
+
+ args = parser.parse_args()
+ test(args)
diff --git a/paddlespeech/server/tests/tts/online/http_client_playaudio.py b/paddlespeech/server/tests/tts/online/http_client_playaudio.py
new file mode 100644
index 000000000..1e7e8064e
--- /dev/null
+++ b/paddlespeech/server/tests/tts/online/http_client_playaudio.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import base64
+import json
+import threading
+import time
+
+import pyaudio
+import requests
+
+mutex = threading.Lock()
+buffer = b''
+p = pyaudio.PyAudio()
+stream = p.open(
+ format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
+max_fail = 50
+
+
+def play_audio():
+ global stream
+ global buffer
+ global max_fail
+ while True:
+ if not buffer:
+ max_fail -= 1
+ time.sleep(0.05)
+ if max_fail < 0:
+ break
+ mutex.acquire()
+ stream.write(buffer)
+ buffer = b''
+ mutex.release()
+
+
+def test(args):
+ global mutex
+ global buffer
+ params = {
+ "text": args.text,
+ "spk_id": args.spk_id,
+ "speed": args.speed,
+ "volume": args.volume,
+ "sample_rate": args.sample_rate,
+ "save_path": ''
+ }
+
+ all_bytes = 0.0
+ t = threading.Thread(target=play_audio)
+ flag = 1
+ url = "http://" + str(args.server) + ":" + str(
+ args.port) + "/paddlespeech/streaming/tts"
+ st = time.time()
+ html = requests.post(url, json.dumps(params), stream=True)
+ for chunk in html.iter_content(chunk_size=1024):
+ mutex.acquire()
+ chunk = base64.b64decode(chunk) # bytes
+ buffer += chunk
+ mutex.release()
+ if flag:
+ first_response = time.time() - st
+ print(f"首包响应:{first_response} s")
+ flag = 0
+ t.start()
+ all_bytes += len(chunk)
+
+ final_response = time.time() - st
+ duration = all_bytes / 2 / 24000
+
+ print(f"尾包响应:{final_response} s")
+ print(f"音频时长:{duration} s")
+ print(f"RTF: {final_response / duration}")
+
+ t.join()
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--text',
+ type=str,
+ default="您好,欢迎使用语音合成服务。",
+ help='A sentence to be synthesized')
+ parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
+ parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
+ parser.add_argument(
+ '--volume', type=float, default=1.0, help='Audio volume')
+ parser.add_argument(
+ '--sample_rate',
+ type=int,
+ default=0,
+ help='Sampling rate, the default is the same as the model')
+ parser.add_argument(
+ "--server", type=str, help="server ip", default="127.0.0.1")
+ parser.add_argument("--port", type=int, help="server port", default=8092)
+
+ args = parser.parse_args()
+ test(args)
diff --git a/paddlespeech/server/tests/tts/online/ws_client.py b/paddlespeech/server/tests/tts/online/ws_client.py
new file mode 100644
index 000000000..eef010cf2
--- /dev/null
+++ b/paddlespeech/server/tests/tts/online/ws_client.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import _thread as thread
+import argparse
+import base64
+import json
+import ssl
+import time
+
+import websocket
+
+flag = 1
+st = 0.0
+all_bytes = b''
+
+
+class WsParam(object):
+    # initialization
+ def __init__(self, text, server="127.0.0.1", port=8090):
+ self.server = server
+ self.port = port
+ self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
+ self.text = text
+
+    # generate the url
+ def create_url(self):
+ return self.url
+
+
+def on_message(ws, message):
+ global flag
+ global st
+ global all_bytes
+
+ try:
+ message = json.loads(message)
+ audio = message["audio"]
+ audio = base64.b64decode(audio) # bytes
+ status = message["status"]
+ all_bytes += audio
+
+ if status == 0:
+ print("create successfully.")
+ elif status == 1:
+ if flag:
+ print(f"首包响应:{time.time() - st} s")
+ flag = 0
+ elif status == 2:
+ final_response = time.time() - st
+ duration = len(all_bytes) / 2.0 / 24000
+ print(f"尾包响应:{final_response} s")
+ print(f"音频时长:{duration} s")
+ print(f"RTF: {final_response / duration}")
+ with open("./out.pcm", "wb") as f:
+ f.write(all_bytes)
+ print("ws is closed")
+ ws.close()
+ else:
+ print("infer error")
+
+ except Exception as e:
+ print("receive msg,but parse exception:", e)
+
+
+# handle websocket errors
+def on_error(ws, error):
+ print("### error:", error)
+
+
+# handle websocket close
+def on_close(ws):
+ print("### closed ###")
+
+
+# handle the websocket connection being established
+def on_open(ws):
+ def run(*args):
+ global st
+ text_base64 = str(
+ base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
+ d = {"text": text_base64}
+ d = json.dumps(d)
+ print("Start sending text data")
+ st = time.time()
+ ws.send(d)
+
+ thread.start_new_thread(run, ())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="A sentence to be synthesized",
+ default="您好,欢迎使用语音合成服务。")
+ parser.add_argument(
+ "--server", type=str, help="server ip", default="127.0.0.1")
+ parser.add_argument("--port", type=int, help="server port", default=8092)
+ args = parser.parse_args()
+
+ print("***************************************")
+ print("Server ip: ", args.server)
+ print("Server port: ", args.port)
+ print("Sentence to be synthesized: ", args.text)
+ print("***************************************")
+
+ wsParam = WsParam(text=args.text, server=args.server, port=args.port)
+
+ websocket.enableTrace(False)
+ wsUrl = wsParam.create_url()
+ ws = websocket.WebSocketApp(
+ wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
+ ws.on_open = on_open
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
diff --git a/paddlespeech/server/tests/tts/online/ws_client_playaudio.py b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py
new file mode 100644
index 000000000..cdeb362df
--- /dev/null
+++ b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import _thread as thread
+import argparse
+import base64
+import json
+import ssl
+import threading
+import time
+
+import pyaudio
+import websocket
+
+mutex = threading.Lock()
+buffer = b''
+p = pyaudio.PyAudio()
+stream = p.open(
+ format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
+flag = 1
+st = 0.0
+all_bytes = 0.0
+
+
+class WsParam(object):
+    # initialization
+ def __init__(self, text, server="127.0.0.1", port=8090):
+ self.server = server
+ self.port = port
+ self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
+ self.text = text
+
+    # generate the url
+ def create_url(self):
+ return self.url
+
+
+def play_audio():
+ global stream
+ global buffer
+ while True:
+ time.sleep(0.05)
+        if not buffer:  # buffer is empty
+ break
+ mutex.acquire()
+ stream.write(buffer)
+ buffer = b''
+ mutex.release()
+
+
+t = threading.Thread(target=play_audio)
+
+
+def on_message(ws, message):
+ global flag
+ global t
+ global buffer
+ global st
+ global all_bytes
+
+ try:
+ message = json.loads(message)
+ audio = message["audio"]
+ audio = base64.b64decode(audio) # bytes
+ status = message["status"]
+ all_bytes += len(audio)
+
+ if status == 0:
+ print("create successfully.")
+ elif status == 1:
+ mutex.acquire()
+ buffer += audio
+ mutex.release()
+ if flag:
+ print(f"首包响应:{time.time() - st} s")
+ flag = 0
+ print("Start playing audio")
+ t.start()
+ elif status == 2:
+ final_response = time.time() - st
+ duration = all_bytes / 2 / 24000
+ print(f"尾包响应:{final_response} s")
+ print(f"音频时长:{duration} s")
+ print(f"RTF: {final_response / duration}")
+ print("ws is closed")
+ ws.close()
+ else:
+ print("infer error")
+
+ except Exception as e:
+ print("receive msg,but parse exception:", e)
+
+
+# handle websocket errors
+def on_error(ws, error):
+ print("### error:", error)
+
+
+# handle websocket close
+def on_close(ws):
+ print("### closed ###")
+
+
+# handle the websocket connection being established
+def on_open(ws):
+ def run(*args):
+ global st
+ text_base64 = str(
+ base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
+ d = {"text": text_base64}
+ d = json.dumps(d)
+ print("Start sending text data")
+ st = time.time()
+ ws.send(d)
+
+ thread.start_new_thread(run, ())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="A sentence to be synthesized",
+ default="您好,欢迎使用语音合成服务。")
+ parser.add_argument(
+ "--server", type=str, help="server ip", default="127.0.0.1")
+ parser.add_argument("--port", type=int, help="server port", default=8092)
+ args = parser.parse_args()
+
+ print("***************************************")
+ print("Server ip: ", args.server)
+ print("Server port: ", args.port)
+ print("Sentence to be synthesized: ", args.text)
+ print("***************************************")
+
+ wsParam = WsParam(text=args.text, server=args.server, port=args.port)
+
+ websocket.enableTrace(False)
+ wsUrl = wsParam.create_url()
+ ws = websocket.WebSocketApp(
+ wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
+ ws.on_open = on_open
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
+ t.join()
+ print("End of playing audio")
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py
index 3cbb495a6..e85b9a27e 100644
--- a/paddlespeech/server/utils/audio_process.py
+++ b/paddlespeech/server/utils/audio_process.py
@@ -103,3 +103,40 @@ def change_speed(sample_raw, speed_rate, sample_rate):
sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy()
return sample_speed
+
+
+def float2pcm(sig, dtype='int16'):
+ """Convert floating point signal with a range from -1 to 1 to PCM.
+
+ Args:
+ sig (array): Input array, must have floating point type.
+ dtype (str, optional): Desired (integer) data type. Defaults to 'int16'.
+
+ Returns:
+        numpy.ndarray: Integer data, scaled and clipped to the range of the given dtype.
+ """
+ sig = np.asarray(sig)
+ if sig.dtype.kind != 'f':
+ raise TypeError("'sig' must be a float array")
+ dtype = np.dtype(dtype)
+ if dtype.kind not in 'iu':
+ raise TypeError("'dtype' must be an integer type")
+
+ i = np.iinfo(dtype)
+ abs_max = 2**(i.bits - 1)
+ offset = i.min + abs_max
+ return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)
+
+
+def pcm2float(data):
+ """pcm int16 to float32
+ Args:
+        data (numpy.array): numpy.int16
+    Returns:
+        data (numpy.array): numpy.float32
+ """
+ if data.dtype == np.int16:
+ data = data.astype("float32")
+ bits = np.iinfo(np.int16).bits
+ data = data / (2**(bits - 1))
+ return data
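A quick round-trip of the two helpers above, showing the int16 scaling and the size of the quantization error:

```python
import numpy as np

from paddlespeech.server.utils.audio_process import float2pcm, pcm2float

# Round-trip sketch: float32 in [-1, 1] -> int16 PCM -> float32.
wav = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype="float32")
pcm = float2pcm(wav)        # int16, scaled by 2**15 and clipped to [-32768, 32767]
restored = pcm2float(pcm)   # float32 again, within one quantization step (1/32768)

print(pcm)                             # [     0  16384 -16384  32767 -32768]
print(np.max(np.abs(wav - restored)))  # ~3.05e-05
```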
diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py
index e9104fa2d..0fe70849d 100644
--- a/paddlespeech/server/utils/util.py
+++ b/paddlespeech/server/utils/util.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import base64
+import math
def wav2base64(wav_file: str):
@@ -31,3 +32,42 @@ def self_check():
""" self check resource
"""
return True
+
+
+def denorm(data, mean, std):
+ """stream am model need to denorm
+ """
+ return data * std + mean
+
+
+def get_chunks(data, block_size, pad_size, step):
+ """Divide data into multiple chunks
+
+ Args:
+        data (tensor): the data to be split into chunks, shape [B, T, C] for the am step or [T, C] for the voc step
+        block_size (int): the number of frames in each chunk body
+        pad_size (int): the number of extra frames padded on each side of a chunk
+ step (str): set "am" or "voc", generate chunk for step am or vocoder(voc)
+
+ Returns:
+ list: chunks list
+ """
+ if step == "am":
+ data_len = data.shape[1]
+ elif step == "voc":
+ data_len = data.shape[0]
+    else:
+        raise ValueError(
+            "Please set the correct step type to get chunks: am or voc")
+
+ chunks = []
+ n = math.ceil(data_len / block_size)
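+    # every chunk keeps pad_size extra frames on each side, so adjacent chunks overlap;
+    # the overlap is meant to reduce boundary artifacts in streaming synthesis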
+ for i in range(n):
+ start = max(0, i * block_size - pad_size)
+ end = min((i + 1) * block_size + pad_size, data_len)
+ if step == "am":
+ chunks.append(data[:, start:end, :])
+ elif step == "voc":
+ chunks.append(data[start:end, :])
+ else:
+ print("Please set correct type to get chunks, am or voc")
+ return chunks
diff --git a/paddlespeech/server/ws/api.py b/paddlespeech/server/ws/api.py
index 10664d114..313fd16f5 100644
--- a/paddlespeech/server/ws/api.py
+++ b/paddlespeech/server/ws/api.py
@@ -16,6 +16,7 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router
+from paddlespeech.server.ws.tts_socket import router as tts_router
_router = APIRouter()
@@ -31,7 +32,7 @@ def setup_router(api_list: List):
if api_name == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
- pass
+ _router.include_router(tts_router)
else:
pass
diff --git a/paddlespeech/server/ws/tts_socket.py b/paddlespeech/server/ws/tts_socket.py
new file mode 100644
index 000000000..11458b3cf
--- /dev/null
+++ b/paddlespeech/server/ws/tts_socket.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from fastapi import APIRouter
+from fastapi import WebSocket
+from fastapi import WebSocketDisconnect
+from starlette.websockets import WebSocketState as WebSocketState
+
+from paddlespeech.cli.log import logger
+from paddlespeech.server.engine.engine_pool import get_engine_pool
+
+router = APIRouter()
+
+
+@router.websocket('/ws/tts')
+async def websocket_endpoint(websocket: WebSocket):
+ await websocket.accept()
+
+ try:
+ # careful here, changed the source code from starlette.websockets
+ assert websocket.application_state == WebSocketState.CONNECTED
+ message = await websocket.receive()
+ websocket._raise_on_disconnect(message)
+
+ # get engine
+ engine_pool = get_engine_pool()
+ tts_engine = engine_pool['tts']
+
+        # parse the received message and get the base64-encoded text
+ message = json.loads(message["text"])
+ text_bese64 = message["text"]
+ sentence = tts_engine.preprocess(text_bese64=text_bese64)
+
+ # run
+ wav_generator = tts_engine.run(sentence)
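+        # the engine is expected to return a generator so the audio can be
+        # streamed to the client chunk by chunk instead of in one response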
+
+ while True:
+ try:
+ tts_results = next(wav_generator)
+ resp = {"status": 1, "audio": tts_results}
+ await websocket.send_json(resp)
+ logger.info("streaming audio...")
+            except StopIteration:
+ resp = {"status": 2, "audio": ''}
+ await websocket.send_json(resp)
+ logger.info("Complete the transmission of audio streams")
+ break
+
+ except WebSocketDisconnect:
+ pass
diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py
index 5bda75451..db1842b2e 100644
--- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py
+++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py
@@ -86,6 +86,9 @@ def process_sentence(config: Dict[str, Any],
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
+ # utt_id may be popped in compare_duration_and_mel_length
+ if utt_id not in sentences:
+ return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = logmel.shape[0]
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index 1188ddfb1..62602a01f 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -104,7 +104,7 @@ def get_voc_output(args, voc_predictor, input):
def parse_args():
parser = argparse.ArgumentParser(
- description="Paddle Infernce with speedyspeech & parallel wavegan.")
+ description="Paddle Infernce with acoustic model & vocoder.")
# acoustic model
parser.add_argument(
'--am',
diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py
new file mode 100644
index 000000000..e8d4d61c3
--- /dev/null
+++ b/paddlespeech/t2s/exps/ort_predict.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import onnxruntime as ort
+import soundfile as sf
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import get_test_dataset
+from paddlespeech.t2s.utils import str2bool
+
+
+def get_sess(args, field='am'):
+    full_name = ''
+    if field == 'am':
+        full_name = args.am
+    elif field == 'voc':
+        full_name = args.voc
+ model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
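+    # the exported onnx file is assumed to be named "<am>.onnx" / "<voc>.onnx" under inference_dir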
+ sess_options = ort.SessionOptions()
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+
+ if args.device == "gpu":
+ # fastspeech2/mb_melgan can't use trt now!
+ if args.use_trt:
+ providers = ['TensorrtExecutionProvider']
+ else:
+ providers = ['CUDAExecutionProvider']
+ elif args.device == "cpu":
+ providers = ['CPUExecutionProvider']
+ sess_options.intra_op_num_threads = args.cpu_threads
+ sess = ort.InferenceSession(
+ model_dir, providers=providers, sess_options=sess_options)
+ return sess
+
+
+def ort_predict(args):
+ # construct dataset for evaluation
+ with jsonlines.open(args.test_metadata, 'r') as reader:
+ test_metadata = list(reader)
+ am_name = args.am[:args.am.rindex('_')]
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+ test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ fs = 24000 if am_dataset != 'ljspeech' else 22050
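+    # csmsc models are trained on 24 kHz audio, ljspeech models on 22.05 kHz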
+
+ # am
+    am_sess = get_sess(args, field='am')
+
+ # vocoder
+    voc_sess = get_sess(args, field='voc')
+
+ # am warmup
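+    # run the session on a few random phone-id sequences of different lengths
+    # so buffers are allocated before the timed inference below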
+ for T in [27, 38, 54]:
+ data = np.random.randint(1, 266, size=(T, ))
+ am_sess.run(None, {"text": data})
+
+ # voc warmup
+ for T in [227, 308, 544]:
+ data = np.random.rand(T, 80).astype("float32")
+ voc_sess.run(None, {"logmel": data})
+ print("warm up done!")
+
+ N = 0
+ T = 0
+ for example in test_dataset:
+ utt_id = example['utt_id']
+ phone_ids = example["text"]
+ with timer() as t:
+ mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
+ mel = mel[0]
+ wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
+
+ N += len(wav[0])
+ T += t.elapse
+ speed = len(wav[0]) / t.elapse
+ rtf = fs / speed
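+            # RTF = synthesis time / generated audio duration; RTF < 1.0 means faster than real time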
+ sf.write(
+ str(output_dir / (utt_id + ".wav")),
+ np.array(wav)[0],
+ samplerate=fs)
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Infernce with onnxruntime.")
+ # acoustic model
+ parser.add_argument(
+ '--am',
+ type=str,
+ default='fastspeech2_csmsc',
+ choices=[
+ 'fastspeech2_csmsc',
+ ],
+ help='Choose acoustic model type of tts task.')
+
+ # voc
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='hifigan_csmsc',
+ choices=['hifigan_csmsc', 'mb_melgan_csmsc'],
+ help='Choose vocoder type of tts task.')
+ # other
+ parser.add_argument(
+ "--inference_dir", type=str, help="dir to save inference models")
+ parser.add_argument("--test_metadata", type=str, help="test metadata.")
+ parser.add_argument("--output_dir", type=str, help="output dir")
+
+ # inference
+ parser.add_argument(
+ "--use_trt",
+ type=str2bool,
+ default=False,
+ help="Whether to use inference engin TensorRT.", )
+
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ choices=["gpu", "cpu"],
+ help="Device selected for inference.", )
+ parser.add_argument('--cpu_threads', type=int, default=1)
+
+ args, _ = parser.parse_known_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ ort_predict(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py
new file mode 100644
index 000000000..8aa04cbc5
--- /dev/null
+++ b/paddlespeech/t2s/exps/ort_predict_e2e.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from pathlib import Path
+
+import numpy as np
+import onnxruntime as ort
+import soundfile as sf
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.utils import str2bool
+
+
+def get_sess(args, field='am'):
+    full_name = ''
+    if field == 'am':
+        full_name = args.am
+    elif field == 'voc':
+        full_name = args.voc
+ model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
+ sess_options = ort.SessionOptions()
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+
+ if args.device == "gpu":
+ # fastspeech2/mb_melgan can't use trt now!
+ if args.use_trt:
+ providers = ['TensorrtExecutionProvider']
+ else:
+ providers = ['CUDAExecutionProvider']
+ elif args.device == "cpu":
+ providers = ['CPUExecutionProvider']
+ sess_options.intra_op_num_threads = args.cpu_threads
+ sess = ort.InferenceSession(
+ model_dir, providers=providers, sess_options=sess_options)
+ return sess
+
+
+def ort_predict(args):
+
+ # frontend
+ frontend = get_frontend(args)
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ sentences = get_sentences(args)
+
+ am_name = args.am[:args.am.rindex('_')]
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+ fs = 24000 if am_dataset != 'ljspeech' else 22050
+
+ # am
+    am_sess = get_sess(args, field='am')
+
+ # vocoder
+    voc_sess = get_sess(args, field='voc')
+
+ # am warmup
+ for T in [27, 38, 54]:
+ data = np.random.randint(1, 266, size=(T, ))
+ am_sess.run(None, {"text": data})
+
+ # voc warmup
+ for T in [227, 308, 544]:
+ data = np.random.rand(T, 80).astype("float32")
+ voc_sess.run(None, {"logmel": data})
+ print("warm up done!")
+
+ # frontend warmup
+ # Loading model cost 0.5+ seconds
+ if args.lang == 'zh':
+ frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
+ else:
+ print("lang should in be 'zh' here!")
+
+ N = 0
+ T = 0
+ merge_sentences = True
+ for utt_id, sentence in sentences:
+ with timer() as t:
+ if args.lang == 'zh':
+ input_ids = frontend.get_input_ids(
+ sentence, merge_sentences=merge_sentences)
+
+ phone_ids = input_ids["phone_ids"]
+ else:
+ print("lang should in be 'zh' here!")
+ # merge_sentences=True here, so we only use the first item of phone_ids
+ phone_ids = phone_ids[0].numpy()
+ mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
+ mel = mel[0]
+ wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
+
+ N += len(wav[0])
+ T += t.elapse
+ speed = len(wav[0]) / t.elapse
+ rtf = fs / speed
+ sf.write(
+ str(output_dir / (utt_id + ".wav")),
+ np.array(wav)[0],
+ samplerate=fs)
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Infernce with onnxruntime.")
+ # acoustic model
+ parser.add_argument(
+ '--am',
+ type=str,
+ default='fastspeech2_csmsc',
+ choices=[
+ 'fastspeech2_csmsc',
+ ],
+ help='Choose acoustic model type of tts task.')
+ parser.add_argument(
+ "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+ parser.add_argument(
+ "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+
+ # voc
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='hifigan_csmsc',
+ choices=['hifigan_csmsc', 'mb_melgan_csmsc'],
+ help='Choose vocoder type of tts task.')
+ # other
+ parser.add_argument(
+ "--inference_dir", type=str, help="dir to save inference models")
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="text to synthesize, a 'utt_id sentence' pair per line")
+ parser.add_argument("--output_dir", type=str, help="output dir")
+ parser.add_argument(
+ '--lang',
+ type=str,
+ default='zh',
+ help='Choose model language. zh or en')
+
+ # inference
+ parser.add_argument(
+ "--use_trt",
+ type=str2bool,
+ default=False,
+ help="Whether to use inference engin TensorRT.", )
+
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ choices=["gpu", "cpu"],
+ help="Device selected for inference.", )
+ parser.add_argument('--cpu_threads', type=int, default=1)
+
+ args, _ = parser.parse_known_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ ort_predict(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py
index 3f81c4e14..e833d1394 100644
--- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py
+++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py
@@ -79,6 +79,9 @@ def process_sentence(config: Dict[str, Any],
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
+ # utt_id may be popped in compare_duration_and_mel_length
+ if utt_id not in sentences:
+ return None
labels = sentences[utt_id][0]
# extract phone and duration
phones = []
diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py
index f38b2d352..7b9906c10 100644
--- a/paddlespeech/t2s/exps/synthesize_streaming.py
+++ b/paddlespeech/t2s/exps/synthesize_streaming.py
@@ -90,6 +90,7 @@ def evaluate(args):
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = True
+ get_tone_ids = False
N = 0
T = 0
@@ -98,8 +99,6 @@ def evaluate(args):
for utt_id, sentence in sentences:
with timer() as t:
- get_tone_ids = False
-
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence,
diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py
index 7f41089eb..14a0d7eae 100644
--- a/paddlespeech/t2s/exps/tacotron2/preprocess.py
+++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py
@@ -82,6 +82,9 @@ def process_sentence(config: Dict[str, Any],
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
+ # utt_id may be popped in compare_duration_and_mel_length
+ if utt_id not in sentences:
+ return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = logmel.shape[0]
diff --git a/paddlespeech/t2s/modules/positional_encoding.py b/paddlespeech/t2s/modules/positional_encoding.py
index 7c368c3aa..715c576f5 100644
--- a/paddlespeech/t2s/modules/positional_encoding.py
+++ b/paddlespeech/t2s/modules/positional_encoding.py
@@ -31,8 +31,9 @@ def sinusoid_position_encoding(num_positions: int,
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
- p = (paddle.unsqueeze(index, -1) *
- omega) / (10000.0**(channel / float(feature_size)))
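+    # compute the denominator 10000**(channel / feature_size) with tensor ops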
+ denominator = channel / float(feature_size)
+ denominator = paddle.to_tensor([10000.0], dtype='float32')**denominator
+ p = (paddle.unsqueeze(index, -1) * omega) / denominator
encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
encodings[:, 0::2] = paddle.sin(p)
encodings[:, 1::2] = paddle.cos(p)
diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py
index ee00cb535..56af61af6 100644
--- a/paddlespeech/vector/cluster/diarization.py
+++ b/paddlespeech/vector/cluster/diarization.py
@@ -747,6 +747,77 @@ def merge_ssegs_same_speaker(lol):
return new_lol
+def write_ders_file(ref_rttm, DER, out_der_file):
+ """Write the final DERs for individual recording.
+
+ Arguments
+ ---------
+ ref_rttm : str
+ Reference RTTM file.
+ DER : array
+ Array containing DER values of each recording.
+ out_der_file : str
+ File to write the DERs.
+ """
+
+ rttm = read_rttm(ref_rttm)
+ spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+ rec_id_list = []
+ count = 0
+
+ with open(out_der_file, "w") as f:
+ for row in spkr_info:
+ a = row.split(" ")
+ rec_id = a[1]
+ if rec_id not in rec_id_list:
+ r = [rec_id, str(round(DER[count], 2))]
+ rec_id_list.append(rec_id)
+ line_str = " ".join(r)
+ f.write("%s\n" % line_str)
+ count += 1
+ r = ["OVERALL ", str(round(DER[count], 2))]
+ line_str = " ".join(r)
+ f.write("%s\n" % line_str)
+
+
+def get_oracle_num_spkrs(rec_id, spkr_info):
+ """
+ Returns actual number of speakers in a recording from the ground-truth.
+ This can be used when the condition is oracle number of speakers.
+
+ Arguments
+ ---------
+ rec_id : str
+        Recording ID for which the number of speakers has to be obtained.
+ spkr_info : list
+ Header of the RTTM file. Starting with `SPKR-INFO`.
+
+ Example
+ -------
+ >>> from speechbrain.processing import diarization as diar
+ >>> spkr_info = ['SPKR-INFO ES2011a 0 unknown ES2011a.A ',
+ ... 'SPKR-INFO ES2011a 0 unknown ES2011a.B ',
+ ... 'SPKR-INFO ES2011a 0 unknown ES2011a.C ',
+ ... 'SPKR-INFO ES2011a 0 unknown ES2011a.D ',
+ ... 'SPKR-INFO ES2011b 0 unknown ES2011b.A ',
+ ... 'SPKR-INFO ES2011b 0 unknown ES2011b.B ',
+ ... 'SPKR-INFO ES2011b 0 unknown ES2011b.C ']
+ >>> diar.get_oracle_num_spkrs('ES2011a', spkr_info)
+ 4
+ >>> diar.get_oracle_num_spkrs('ES2011b', spkr_info)
+ 3
+ """
+
+ num_spkrs = 0
+ for line in spkr_info:
+ if rec_id in line:
+ # Since rec_id is prefix for each speaker
+ num_spkrs += 1
+
+ return num_spkrs
+
+
def distribute_overlap(lol):
"""
Distributes the overlapped speech equally among the adjacent segments
@@ -827,6 +898,29 @@ def distribute_overlap(lol):
return new_lol
+def read_rttm(rttm_file_path):
+ """
+ Reads and returns RTTM in list format.
+
+ Arguments
+ ---------
+ rttm_file_path : str
+ Path to the RTTM file to be read.
+
+ Returns
+ -------
+ rttm : list
+ List containing rows of RTTM file.
+ """
+
+ rttm = []
+ with open(rttm_file_path, "r") as f:
+ for line in f:
+ entry = line[:-1]
+ rttm.append(entry)
+ return rttm
+
+
def write_rttm(segs_list, out_rttm_file):
"""
Writes the segment list in RTTM format (A standard NIST format).
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py
index d0de6dc51..70b1521ed 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/test.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py
@@ -21,10 +21,11 @@ from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode
-from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
+from paddlespeech.vector.io.dataset import CSVDataset
+from paddlespeech.vector.io.embedding_norm import InputNormalization
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
@@ -32,6 +33,91 @@ from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
+def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
+ id2embedding):
+ """compute the dataset embeddings
+
+ Args:
+ data_loader (_type_): _description_
+ model (_type_): _description_
+ mean_var_norm_emb (_type_): _description_
+ config (_type_): _description_
+ """
+ logger.info(
+ f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
+ with paddle.no_grad():
+ for batch_idx, batch in enumerate(tqdm(data_loader)):
+
+            # stage 8-1: extract the audio embedding
+ ids, feats, lengths = batch['ids'], batch['feats'], batch['lengths']
+ embeddings = model.backbone(feats, lengths).squeeze(
+ -1) # (N, emb_size, 1) -> (N, emb_size)
+
+ # Global embedding normalization.
+            # if we use global embedding normalization,
+            # the EER can be reduced by about a relative 10%
+ if config.global_embedding_norm and mean_var_norm_emb:
+ lengths = paddle.ones([embeddings.shape[0]])
+ embeddings = mean_var_norm_emb(embeddings, lengths)
+
+ # Update embedding dict.
+ id2embedding.update(dict(zip(ids, embeddings)))
+
+
+def compute_verification_scores(id2embedding, train_cohort, config):
+ labels = []
+ enroll_ids = []
+ test_ids = []
+ logger.info(f"read the trial from {config.verification_file}")
+ cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)
+ scores = []
+ with open(config.verification_file, 'r') as f:
+ for line in f.readlines():
+ label, enroll_id, test_id = line.strip().split(' ')
+ enroll_id = enroll_id.split('.')[0].replace('/', '-')
+ test_id = test_id.split('.')[0].replace('/', '-')
+ labels.append(int(label))
+
+ enroll_emb = id2embedding[enroll_id]
+ test_emb = id2embedding[test_id]
+ score = cos_sim_func(enroll_emb, test_emb).item()
+
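+            # adaptive score normalization: the raw cosine score is normalized
+            # with statistics of scores against a cohort of training (impostor) embeddings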
+ if "score_norm" in config:
+ # Getting norm stats for enroll impostors
+ enroll_rep = paddle.tile(
+ enroll_emb, repeat_times=[train_cohort.shape[0], 1])
+ score_e_c = cos_sim_func(enroll_rep, train_cohort)
+ if "cohort_size" in config:
+ score_e_c, _ = paddle.topk(
+ score_e_c, k=config.cohort_size, axis=0)
+ mean_e_c = paddle.mean(score_e_c, axis=0)
+ std_e_c = paddle.std(score_e_c, axis=0)
+
+ # Getting norm stats for test impostors
+ test_rep = paddle.tile(
+ test_emb, repeat_times=[train_cohort.shape[0], 1])
+ score_t_c = cos_sim_func(test_rep, train_cohort)
+ if "cohort_size" in config:
+ score_t_c, _ = paddle.topk(
+ score_t_c, k=config.cohort_size, axis=0)
+ mean_t_c = paddle.mean(score_t_c, axis=0)
+ std_t_c = paddle.std(score_t_c, axis=0)
+
+ if config.score_norm == "s-norm":
+ score_e = (score - mean_e_c) / std_e_c
+ score_t = (score - mean_t_c) / std_t_c
+
+ score = 0.5 * (score_e + score_t)
+ elif config.score_norm == "z-norm":
+ score = (score - mean_e_c) / std_e_c
+ elif config.score_norm == "t-norm":
+ score = (score - mean_t_c) / std_t_c
+
+ scores.append(score)
+
+ return scores, labels
+
+
def main(args, config):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
@@ -58,9 +144,8 @@ def main(args, config):
# stage4: construct the enroll and test dataloader
- enroll_dataset = VoxCeleb(
- subset='enroll',
- target_dir=args.data_dir,
+ enroll_dataset = CSVDataset(
+ os.path.join(args.data_dir, "vox/csv/enroll.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@@ -68,16 +153,15 @@ def main(args, config):
hop_length=config.hop_size)
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
- shuffle=True) # Shuffle to make embedding normalization more robust.
- enrol_loader = DataLoader(enroll_dataset,
+        shuffle=False)  # no need to shuffle when only extracting embeddings
+ enroll_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize(
- x, mean_norm=True, std_norm=False),
+ x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
- test_dataset = VoxCeleb(
- subset='test',
- target_dir=args.data_dir,
+ test_dataset = CSVDataset(
+ os.path.join(args.data_dir, "vox/csv/test.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@@ -85,7 +169,7 @@ def main(args, config):
hop_length=config.hop_size)
test_sampler = BatchSampler(
- test_dataset, batch_size=config.batch_size, shuffle=True)
+ test_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset,
batch_sampler=test_sampler,
collate_fn=lambda x: batch_feature_normalize(
@@ -97,75 +181,65 @@ def main(args, config):
# stage6: global embedding norm to imporve the performance
logger.info(f"global embedding norm: {config.global_embedding_norm}")
- if config.global_embedding_norm:
- global_embedding_mean = None
- global_embedding_std = None
- mean_norm_flag = config.embedding_mean_norm
- std_norm_flag = config.embedding_std_norm
- batch_count = 0
# stage7: Compute embeddings of audios in enrol and test dataset from model.
+
+    # keep mean_var_norm_emb defined even when global embedding norm is disabled
+    mean_var_norm_emb = None
+    if config.global_embedding_norm:
+ mean_var_norm_emb = InputNormalization(
+ norm_type="global",
+ mean_norm=config.embedding_mean_norm,
+ std_norm=config.embedding_std_norm)
+
+ if "score_norm" in config:
+ logger.info(f"we will do score norm: {config.score_norm}")
+ train_dataset = CSVDataset(
+ os.path.join(args.data_dir, "vox/csv/train.csv"),
+ feat_type='melspectrogram',
+ n_train_snts=config.n_train_snts,
+ random_chunk=False,
+ n_mels=config.n_mels,
+ window_size=config.window_size,
+ hop_length=config.hop_size)
+ train_sampler = BatchSampler(
+ train_dataset, batch_size=config.batch_size, shuffle=False)
+ train_loader = DataLoader(train_dataset,
+ batch_sampler=train_sampler,
+ collate_fn=lambda x: batch_feature_normalize(
+ x, mean_norm=True, std_norm=False),
+ num_workers=config.num_workers,
+ return_list=True,)
+
id2embedding = {}
# Run multi times to make embedding normalization more stable.
- for i in range(2):
- for dl in [enrol_loader, test_loader]:
- logger.info(
- f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
- )
- with paddle.no_grad():
- for batch_idx, batch in enumerate(tqdm(dl)):
-
- # stage 8-1: extrac the audio embedding
- ids, feats, lengths = batch['ids'], batch['feats'], batch[
- 'lengths']
- embeddings = model.backbone(feats, lengths).squeeze(
- -1).numpy() # (N, emb_size, 1) -> (N, emb_size)
-
- # Global embedding normalization.
- # if we use the global embedding norm
- # eer can reduece about relative 10%
- if config.global_embedding_norm:
- batch_count += 1
- current_mean = embeddings.mean(
- axis=0) if mean_norm_flag else 0
- current_std = embeddings.std(
- axis=0) if std_norm_flag else 1
- # Update global mean and std.
- if global_embedding_mean is None and global_embedding_std is None:
- global_embedding_mean, global_embedding_std = current_mean, current_std
- else:
- weight = 1 / batch_count # Weight decay by batches.
- global_embedding_mean = (
- 1 - weight
- ) * global_embedding_mean + weight * current_mean
- global_embedding_std = (
- 1 - weight
- ) * global_embedding_std + weight * current_std
- # Apply global embedding normalization.
- embeddings = (embeddings - global_embedding_mean
- ) / global_embedding_std
-
- # Update embedding dict.
- id2embedding.update(dict(zip(ids, embeddings)))
+ logger.info("First loop for enroll and test dataset")
+ compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
+ id2embedding)
+ compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
+ id2embedding)
+
+ logger.info("Second loop for enroll and test dataset")
+ compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
+ id2embedding)
+ compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
+ id2embedding)
+    if mean_var_norm_emb is not None:
+        mean_var_norm_emb.save(
+            os.path.join(args.load_checkpoint, "mean_var_norm_emb"))
# stage 8: Compute cosine scores.
- labels = []
- enroll_ids = []
- test_ids = []
- logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
- with open(VoxCeleb.veri_test_file, 'r') as f:
- for line in f.readlines():
- label, enroll_id, test_id = line.strip().split(' ')
- labels.append(int(label))
- enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
- test_ids.append(test_id.split('.')[0].replace('/', '-'))
-
- cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
- enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
- np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')),
- [enroll_ids, test_ids
- ]) # (N, emb_size)
- scores = cos_sim_func(enrol_embeddings, test_embeddings)
+ train_cohort = None
+ if "score_norm" in config:
+ train_embeddings = {}
+        # cohort embeddings are not mean/std normalized
+ compute_dataset_embedding(train_loader, model, None, config,
+ train_embeddings)
+ train_cohort = paddle.stack(list(train_embeddings.values()))
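+        # cohort matrix of shape (num_cohort_utterances, emb_size) used for score normalization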
+
+ # compute the scores
+ scores, labels = compute_verification_scores(id2embedding, train_cohort,
+ config)
+
+ # compute the EER and threshold
+ scores = paddle.to_tensor(scores)
EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
logger.info(
f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index 257b97abe..b777dae89 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -23,13 +23,13 @@ from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddleaudio.compliance.librosa import melspectrogram
-from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
+from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
@@ -54,8 +54,12 @@ def main(args, config):
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
- train_dataset = VoxCeleb('train', target_dir=args.data_dir)
- dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
+ train_dataset = CSVDataset(
+ csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
+ label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
+ dev_dataset = CSVDataset(
+ csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
+ label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@@ -67,7 +71,7 @@ def main(args, config):
# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(
- backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
+ backbone=ecapa_tdnn, num_class=config.num_speakers)
# stage5: build the optimizer, we now only construct the AdamW optimizer
# 140000 is single gpu steps
@@ -193,15 +197,15 @@ def main(args, config):
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
- train_run_cost += time.time() - train_start
# stage 9-8: Calculate average loss per batch
- avg_loss += loss.numpy()[0]
+        avg_loss += loss.item()
# stage 9-9: Calculate metrics, which is one-best accuracy
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
+ train_run_cost += time.time() - train_start
timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
@@ -220,8 +224,9 @@ def main(args, config):
train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
train_run_cost / config.log_interval)
- print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
- lr, timer.timing, timer.eta)
+
+            print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.5f} | ETA {}'.format(
+ lr, timer.timing, timer.ips, timer.eta)
logger.info(print_msg)
avg_loss = 0
diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py
index 3baace139..0aa89c6a3 100644
--- a/paddlespeech/vector/io/augment.py
+++ b/paddlespeech/vector/io/augment.py
@@ -14,6 +14,7 @@
# this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import math
+import os
from typing import List
import numpy as np
@@ -21,8 +22,8 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
-from paddleaudio.datasets.rirs_noises import OpenRIRNoise
from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.io.signal_processing import compute_amplitude
from paddlespeech.vector.io.signal_processing import convolve1d
from paddlespeech.vector.io.signal_processing import dB_to_amplitude
@@ -509,7 +510,7 @@ class AddNoise(nn.Layer):
assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
return np.pad(x, [0, w], mode=mode, **kwargs)
- ids = [item['id'] for item in batch]
+ ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[0] for item in batch])
waveforms = list(
map(lambda x: pad(x, max(max_length, lengths.max().item())),
@@ -589,7 +590,7 @@ class AddReverb(nn.Layer):
assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
return np.pad(x, [0, w], mode=mode, **kwargs)
- ids = [item['id'] for item in batch]
+ ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[0] for item in batch])
waveforms = list(
map(lambda x: pad(x, lengths.max().item()),
@@ -839,8 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
List[paddle.nn.Layer]: all augment process
"""
logger.info("start to build the augment pipeline")
- noise_dataset = OpenRIRNoise('noise', target_dir=target_dir)
- rir_dataset = OpenRIRNoise('rir', target_dir=target_dir)
+ noise_dataset = CSVDataset(csv_path=os.path.join(target_dir,
+ "rir_noise/csv/noise.csv"))
+ rir_dataset = CSVDataset(csv_path=os.path.join(target_dir,
+ "rir_noise/csv/rir.csv"))
wavedrop = TimeDomainSpecAugment(
sample_rate=16000,
diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py
index 92ca990cf..5049d1946 100644
--- a/paddlespeech/vector/io/batch.py
+++ b/paddlespeech/vector/io/batch.py
@@ -17,6 +17,17 @@ import paddle
def waveform_collate_fn(batch):
+ """Wrap the waveform into a batch form
+
+ Args:
+ batch (list): the waveform list from the dataloader
+ the item of data include several field
+ feat: the utterance waveform data
+ label: the utterance label encoding data
+
+ Returns:
+ dict: the batch data to dataloader
+ """
waveforms = np.stack([item['feat'] for item in batch])
labels = np.stack([item['label'] for item in batch])
@@ -27,6 +38,18 @@ def feature_normalize(feats: paddle.Tensor,
mean_norm: bool=True,
std_norm: bool=True,
convert_to_numpy: bool=False):
+ """Do one utterance feature normalization
+
+ Args:
+ feats (paddle.Tensor): the original utterance feat, such as fbank, mfcc
+ mean_norm (bool, optional): mean norm flag. Defaults to True.
+ std_norm (bool, optional): std norm flag. Defaults to True.
+ convert_to_numpy (bool, optional): convert the paddle.tensor to numpy
+ and do feature norm with numpy. Defaults to False.
+
+ Returns:
+ paddle.Tensor : the normalized feats
+ """
# Features normalization if needed
# numpy.mean is a little with paddle.mean about 1e-6
if convert_to_numpy:
@@ -60,7 +83,17 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
- ids = [item['id'] for item in batch]
+ """Do batch utterance features normalization
+
+ Args:
+ batch (list): the batch feature from dataloader
+ mean_norm (bool, optional): mean normalization flag. Defaults to True.
+ std_norm (bool, optional): std normalization flag. Defaults to True.
+
+ Returns:
+ dict: the normalized batch features
+ """
+ ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[1] for item in batch])
feats = list(
map(lambda x: pad_right_2d(x, lengths.max()),
diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py
new file mode 100644
index 000000000..316c8ac34
--- /dev/null
+++ b/paddlespeech/vector/io/dataset.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+from dataclasses import dataclass
+from dataclasses import fields
+
+from paddle.io import Dataset
+
+from paddleaudio import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddlespeech.s2t.utils.log import Log
+logger = Log(__name__).getlog()
+
+# the audio meta info in the vector CSVDataset
+# utt_id: the utterance segment name
+# duration: utterance segment time
+# wav: utterance file path
+# start: start point in the original wav file
+# stop: stop point in the original wav file
+# label: the utterance segment's label id
+
+
+@dataclass
+class meta_info:
+ """the audio meta info in the vector CSVDataset
+
+ Args:
+ utt_id (str): the utterance segment name
+ duration (float): utterance segment time
+ wav (str): utterance file path
+ start (int): start point in the original wav file
+ stop (int): stop point in the original wav file
+        label (str): the utterance segment's label id
+ """
+ utt_id: str
+ duration: float
+ wav: str
+ start: int
+ stop: int
+ label: str
+
+
+# csv dataset support feature type
+# raw: return the pcm data sample point
+# melspectrogram: fbank feature
+feat_funcs = {
+ 'raw': None,
+ 'melspectrogram': melspectrogram,
+}
+
+
+class CSVDataset(Dataset):
+ def __init__(self,
+ csv_path,
+ label2id_path=None,
+ config=None,
+ random_chunk=True,
+ feat_type: str="raw",
+ n_train_snts: int=-1,
+ **kwargs):
+ """Implement the CSV Dataset
+
+ Args:
+ csv_path (str): csv dataset file path
+ label2id_path (str): the utterance label to integer id map file path
+ config (CfgNode): yaml config
+            feat_type (str): dataset feature type. If it is "raw", the raw pcm data is returned.
+            n_train_snts (int): select only the first n_train_snts samples from the dataset.
+                                If n_train_snts = -1, the dataset loads all the samples.
+                                Default value is -1.
+ kwargs : feature type args
+ """
+ super().__init__()
+ self.csv_path = csv_path
+ self.label2id_path = label2id_path
+ self.config = config
+ self.random_chunk = random_chunk
+ self.feat_type = feat_type
+ self.n_train_snts = n_train_snts
+ self.feat_config = kwargs
+ self.id2label = {}
+ self.label2id = {}
+ self.data = self.load_data_csv()
+ self.load_speaker_to_label()
+
+ def load_data_csv(self):
+ """Load the csv dataset content and store them in the data property
+        The csv dataset has six fields:
+        audio_id (utt_id), audio duration, wav file path, segment start point,
+        segment stop point and utterance label.
+        Note that during training, the utterance label must have a mapping to an integer id in label2id_path
+
+ Returns:
+ list: the csv data with meta_info type
+ """
+ data = []
+
+ with open(self.csv_path, 'r') as rf:
+ for line in rf.readlines()[1:]:
+ audio_id, duration, wav, start, stop, spk_id = line.strip(
+ ).split(',')
+ data.append(
+ meta_info(audio_id,
+ float(duration), wav,
+ int(start), int(stop), spk_id))
+ if self.n_train_snts > 0:
+ sample_num = min(self.n_train_snts, len(data))
+ data = data[0:sample_num]
+
+ return data
+
+ def load_speaker_to_label(self):
+ """Load the utterance label map content.
+        In the vector domain, we call the utterance label the speaker label.
+        It is the real speaker label in speaker verification,
+        and the language label in language identification.
+ """
+ if not self.label2id_path:
+ logger.warning("No speaker id to label file")
+ return
+
+ with open(self.label2id_path, 'r') as f:
+ for line in f.readlines():
+ label_name, label_id = line.strip().split(' ')
+ self.label2id[label_name] = int(label_id)
+ self.id2label[int(label_id)] = label_name
+
+ def convert_to_record(self, idx: int):
+ """convert the dataset sample to training record the CSV Dataset
+
+ Args:
+ idx (int) : the request index in all the dataset
+ """
+ sample = self.data[idx]
+
+ record = {}
+        # copy every field of the meta_info dataclass into the record dict
+ for field in fields(sample):
+ record[field.name] = getattr(sample, field.name)
+
+ waveform, sr = load_audio(record['wav'])
+
+ # random select a chunk audio samples from the audio
+ if self.config and self.config.random_chunk:
+ num_wav_samples = waveform.shape[0]
+ num_chunk_samples = int(self.config.chunk_duration * sr)
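+            # pick a random start so that different epochs see different chunks of the utterance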
+ start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
+ stop = start + num_chunk_samples
+ else:
+ start = record['start']
+ stop = record['stop']
+
+ # we only return the waveform as feat
+ waveform = waveform[start:stop]
+
+        # all available feature types are listed in feat_funcs
+ assert self.feat_type in feat_funcs.keys(), \
+ f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
+ feat_func = feat_funcs[self.feat_type]
+ feat = feat_func(
+ waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+ record.update({'feat': feat})
+ if self.label2id:
+ record.update({'label': self.label2id[record['label']]})
+
+ return record
+
+ def __getitem__(self, idx):
+ """Return the specific index sample
+
+ Args:
+ idx (int) : the request index in all the dataset
+ """
+ return self.convert_to_record(idx)
+
+ def __len__(self):
+ """Return the dataset length
+
+ Returns:
+ int: the length num of the dataset
+ """
+ return len(self.data)
diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py
new file mode 100644
index 000000000..5ffd2c186
--- /dev/null
+++ b/paddlespeech/vector/io/dataset_from_json.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from dataclasses import dataclass
+from dataclasses import fields
+from paddle.io import Dataset
+
+from paddleaudio import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddleaudio.compliance.librosa import mfcc
+
+
+@dataclass
+class meta_info:
+ """the audio meta info in the vector JSONDataset
+ Args:
+ id (str): the segment name
+ duration (float): segment time
+ wav (str): wav file path
+ start (int): start point in the original wav file
+ stop (int): stop point in the original wav file
+        record_id (str): the record id
+ """
+ id: str
+ duration: float
+ wav: str
+ start: int
+ stop: int
+ record_id: str
+
+
+# json dataset support feature type
+feat_funcs = {
+ 'raw': None,
+ 'melspectrogram': melspectrogram,
+ 'mfcc': mfcc,
+}
+
+
+class JSONDataset(Dataset):
+ """
+ dataset from json file.
+ """
+
+ def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
+ """
+        Args:
+            json_file (:obj:`str`): Data prep JSON file.
+            feat_type (:obj:`str`, `optional`, defaults to `raw`):
+                It identifies the feature type that the user wants to extract from an audio file.
+ """
+ if feat_type not in feat_funcs.keys():
+ raise RuntimeError(
+ f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
+ )
+
+ self.json_file = json_file
+ self.feat_type = feat_type
+ self.feat_config = kwargs
+ self._data = self._get_data()
+ super(JSONDataset, self).__init__()
+
+ def _get_data(self):
+ with open(self.json_file, "r") as f:
+ meta_data = json.load(f)
+ data = []
+ for key in meta_data:
+ sub_seg = meta_data[key]["wav"]
+ wav = sub_seg["file"]
+ duration = sub_seg["duration"]
+ start = sub_seg["start"]
+ stop = sub_seg["stop"]
+ rec_id = str(key).rsplit("_", 2)[0]
+ data.append(
+ meta_info(
+ str(key),
+ float(duration), wav, int(start), int(stop), str(rec_id)))
+ return data
+
+ def _convert_to_record(self, idx: int):
+ sample = self._data[idx]
+
+ record = {}
+ # To show all fields in a namedtuple
+ for field in fields(sample):
+ record[field.name] = getattr(sample, field.name)
+
+ waveform, sr = load_audio(record['wav'])
+ waveform = waveform[record['start']:record['stop']]
+
+ feat_func = feat_funcs[self.feat_type]
+ feat = feat_func(
+ waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+ record.update({'feat': feat})
+
+ return record
+
+ def __getitem__(self, idx):
+ return self._convert_to_record(idx)
+
+ def __len__(self):
+ return len(self._data)
diff --git a/paddlespeech/vector/io/embedding_norm.py b/paddlespeech/vector/io/embedding_norm.py
new file mode 100644
index 000000000..619f37101
--- /dev/null
+++ b/paddlespeech/vector/io/embedding_norm.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+import paddle
+
+
+class InputNormalization:
+ spk_dict_mean: Dict[int, paddle.Tensor]
+ spk_dict_std: Dict[int, paddle.Tensor]
+ spk_dict_count: Dict[int, int]
+
+ def __init__(
+ self,
+ mean_norm=True,
+ std_norm=True,
+ norm_type="global", ):
+ """Do feature or embedding mean and std norm
+
+ Args:
+ mean_norm (bool, optional): mean norm flag. Defaults to True.
+ std_norm (bool, optional): std norm flag. Defaults to True.
+ norm_type (str, optional): norm type. Defaults to "global".
+ """
+ super().__init__()
+ self.training = True
+ self.mean_norm = mean_norm
+ self.std_norm = std_norm
+ self.norm_type = norm_type
+ self.glob_mean = paddle.to_tensor([0], dtype="float32")
+ self.glob_std = paddle.to_tensor([0], dtype="float32")
+ self.spk_dict_mean = {}
+ self.spk_dict_std = {}
+ self.spk_dict_count = {}
+ self.weight = 1.0
+ self.count = 0
+ self.eps = 1e-10
+
+ def __call__(self,
+ x,
+ lengths,
+ spk_ids=paddle.to_tensor([], dtype="float32")):
+ """Returns the tensor with the surrounding context.
+ Args:
+ x (paddle.Tensor): A batch of tensors.
+ lengths (paddle.Tensor): A batch of tensors containing the relative length of each
+ sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
+ computing stats on zero-padded steps.
+            spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
+ It is used to perform per-speaker normalization when
+ norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").
+ Returns:
+ paddle.Tensor: The normalized feature or embedding
+ """
+ N_batches = x.shape[0]
+ current_means = []
+ current_stds = []
+
+ for snt_id in range(N_batches):
+
+ # Avoiding padded time steps
+ # actual size is the actual time data length
+ actual_size = paddle.round(lengths[snt_id] *
+ x.shape[1]).astype("int32")
+ # computing actual time data statistics
+ current_mean, current_std = self._compute_current_stats(
+ x[snt_id, 0:actual_size, ...].unsqueeze(0))
+ current_means.append(current_mean)
+ current_stds.append(current_std)
+
+ if self.norm_type == "global":
+ current_mean = paddle.mean(paddle.stack(current_means), axis=0)
+ current_std = paddle.mean(paddle.stack(current_stds), axis=0)
+
+ if self.norm_type == "global":
+
+ if self.training:
+ if self.count == 0:
+ self.glob_mean = current_mean
+ self.glob_std = current_std
+
+ else:
+ self.weight = 1 / (self.count + 1)
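+                    # running average over batches: the i-th batch contributes with weight 1/i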
+
+ self.glob_mean = (
+ 1 - self.weight
+ ) * self.glob_mean + self.weight * current_mean
+
+ self.glob_std = (
+ 1 - self.weight
+ ) * self.glob_std + self.weight * current_std
+
+            # detach so the running statistics do not keep the autograd graph
+            self.glob_mean = self.glob_mean.detach()
+            self.glob_std = self.glob_std.detach()
+
+ self.count = self.count + 1
+ x = (x - self.glob_mean) / (self.glob_std)
+ return x
+
+ def _compute_current_stats(self, x):
+ """Returns the tensor with the surrounding context.
+
+ Args:
+ x (paddle.Tensor): A batch of tensors.
+
+ Returns:
+ the statistics of the data
+ """
+ # Compute current mean
+ if self.mean_norm:
+ current_mean = paddle.mean(x, axis=0).detach()
+ else:
+ current_mean = paddle.to_tensor([0.0], dtype="float32")
+
+ # Compute current std
+ if self.std_norm:
+ current_std = paddle.std(x, axis=0).detach()
+ else:
+ current_std = paddle.to_tensor([1.0], dtype="float32")
+
+ # Improving numerical stability of std
+ current_std = paddle.maximum(current_std,
+ self.eps * paddle.ones_like(current_std))
+
+ return current_mean, current_std
+
+ def _statistics_dict(self):
+ """Fills the dictionary containing the normalization statistics.
+ """
+ state = {}
+ state["count"] = self.count
+ state["glob_mean"] = self.glob_mean
+ state["glob_std"] = self.glob_std
+ state["spk_dict_mean"] = self.spk_dict_mean
+ state["spk_dict_std"] = self.spk_dict_std
+ state["spk_dict_count"] = self.spk_dict_count
+
+ return state
+
+ def _load_statistics_dict(self, state):
+ """Loads the dictionary containing the statistics.
+
+ Arguments
+ ---------
+ state : dict
+ A dictionary containing the normalization statistics.
+ """
+ self.count = state["count"]
+ if isinstance(state["glob_mean"], int):
+ self.glob_mean = state["glob_mean"]
+ self.glob_std = state["glob_std"]
+ else:
+ self.glob_mean = state["glob_mean"] # .to(self.device_inp)
+ self.glob_std = state["glob_std"] # .to(self.device_inp)
+
+ # Loading the spk_dict_mean in the right device
+ self.spk_dict_mean = {}
+ for spk in state["spk_dict_mean"]:
+ self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]
+
+ # Loading the spk_dict_std in the right device
+ self.spk_dict_std = {}
+ for spk in state["spk_dict_std"]:
+ self.spk_dict_std[spk] = state["spk_dict_std"][spk]
+
+ self.spk_dict_count = state["spk_dict_count"]
+
+ return state
+
+ def to(self, device):
+ """Puts the needed tensors in the right device.
+ """
+        # this class is not a nn.Layer, so only the stored tensors are moved
+ self.glob_mean = self.glob_mean.to(device)
+ self.glob_std = self.glob_std.to(device)
+ for spk in self.spk_dict_mean:
+ self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
+ self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
+ return self
+
+ def save(self, path):
+ """Save statistic dictionary.
+
+ Args:
+ path (str): A path where to save the dictionary.
+ """
+ stats = self._statistics_dict()
+ paddle.save(stats, path)
+
+ def _load(self, path, end_of_epoch=False, device=None):
+ """Load statistic dictionary.
+
+ Arguments
+ ---------
+ path : str
+ The path of the statistic dictionary
+        device : str, None
+            Unused here, kept for interface compatibility.
+ """
+ del end_of_epoch # Unused here.
+        stats = paddle.load(path)
+ self._load_statistics_dict(stats)
diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py
index 0e7287cd3..895ff13f4 100644
--- a/paddlespeech/vector/models/ecapa_tdnn.py
+++ b/paddlespeech/vector/models/ecapa_tdnn.py
@@ -79,6 +79,20 @@ class Conv1d(nn.Layer):
bias_attr=bias, )
def forward(self, x):
+ """Do conv1d forward
+
+ Args:
+ x (paddle.Tensor): [N, C, L] input data,
+ N is the batch,
+ C is the data dimension,
+ L is the time
+
+ Raises:
+ ValueError: only support the same padding type
+
+ Returns:
+ paddle.Tensor: the value of conv1d
+ """
if self.padding == "same":
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
@@ -88,6 +102,20 @@ class Conv1d(nn.Layer):
return self.conv(x)
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
+ """Padding the input data
+
+ Args:
+ x (paddle.Tensor): [N, C, L] input data
+ N is the batch,
+ C is the data dimension,
+ L is the time
+ kernel_size (int): 1-d convolution kernel size
+ dilation (int): 1-d convolution dilation
+ stride (int): 1-d convolution stride
+
+ Returns:
+ paddle.Tensor: the padded input data
+ """
L_in = x.shape[-1] # Detecting input shape
padding = self._get_padding_elem(L_in, stride, kernel_size,
dilation) # Time padding
@@ -101,6 +129,17 @@ class Conv1d(nn.Layer):
stride: int,
kernel_size: int,
dilation: int):
+ """Calculate the padding value in same mode
+
+ Args:
+            L_in (int): the time length of the input data
+            stride (int): 1-d convolution stride
+            kernel_size (int): 1-d convolution kernel size
+            dilation (int): 1-d convolution dilation
+
+ Returns:
+ int: return the padding value in same mode
+ """
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
@@ -245,6 +284,13 @@ class SEBlock(nn.Layer):
class AttentiveStatisticsPooling(nn.Layer):
def __init__(self, channels, attention_channels=128, global_context=True):
+ """Compute the speaker verification statistics
+ The detail info is section 3.1 in https://arxiv.org/pdf/1709.01507.pdf
+ Args:
+ channels (int): input data channel or data dimension
+ attention_channels (int, optional): attention dimension. Defaults to 128.
+ global_context (bool, optional): If use the global context information. Defaults to True.
+ """
super().__init__()
self.eps = 1e-12
diff --git a/paddlespeech/vector/utils/time.py b/paddlespeech/vector/utils/time.py
index 8e85b0e12..9dfbbe1f7 100644
--- a/paddlespeech/vector/utils/time.py
+++ b/paddlespeech/vector/utils/time.py
@@ -23,6 +23,7 @@ class Timer(object):
self.last_start_step = 0
self.current_step = 0
self._is_running = True
+ self.cur_ips = 0
def start(self):
self.last_time = time.time()
@@ -43,12 +44,17 @@ class Timer(object):
self.last_start_step = self.current_step
time_used = time.time() - self.last_time
self.last_time = time.time()
+ self.cur_ips = run_steps / time_used
return time_used / run_steps
@property
def is_running(self) -> bool:
return self._is_running
+ @property
+ def ips(self) -> float:
+ return self.cur_ips
+
@property
def eta(self) -> str:
if not self.is_running:
diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py
new file mode 100644
index 000000000..46de7ffaa
--- /dev/null
+++ b/paddlespeech/vector/utils/vector_utils.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def get_chunks(seg_dur, audio_id, audio_duration):
+ """Get all chunk segments from an utterance
+
+ Args:
+ seg_dur (float): segment chunk duration, in seconds
+ audio_id (str): utterance name
+ audio_duration (float): utterance duration, in seconds
+
+ Returns:
+ List: all the chunk segments
+ """
+ num_chunks = int(audio_duration / seg_dur) # all in seconds
+ chunk_lst = [
+ audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
+ for i in range(num_chunks)
+ ]
+ return chunk_lst
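
A quick worked example of the helper above (values are hypothetical):

chunks = get_chunks(seg_dur=3.0, audio_id="utt1", audio_duration=10.0)
# num_chunks = int(10.0 / 3.0) = 3, so chunks is:
#   ["utt1_0.0_3.0", "utt1_3.0_6.0", "utt1_6.0_9.0"]
# the trailing 1.0 s that does not fill a whole chunk is dropped
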
diff --git a/speechx/examples/aishell/local/compute-wer.py b/speechx/examples/aishell/local/compute-wer.py
new file mode 100755
index 000000000..a3eefc0dc
--- /dev/null
+++ b/speechx/examples/aishell/local/compute-wer.py
@@ -0,0 +1,500 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+import re, sys, unicodedata
+import codecs
+
+remove_tag = True
+spacelist= [' ', '\t', '\r', '\n']
+puncts = ['!', ',', '?',
+ '、', '。', '!', ',', ';', '?',
+ ':', '「', '」', '︰', '『', '』', '《', '》']
+
+def characterize(string) :
+ res = []
+ i = 0
+ while i < len(string):
+ char = string[i]
+ if char in puncts:
+ i += 1
+ continue
+ cat1 = unicodedata.category(char)
+ #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+ if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
+ i += 1
+ continue
+ if cat1 == 'Lo': # letter-other
+ res.append(char)
+ i += 1
+ else:
+ # some input looks like "<tag>word"; we want to separate it into two words.
+ sep = ' '
+ if char == '<': sep = '>'
+ j = i+1
+ while j < len(string):
+ c = string[j]
+ if ord(c) >= 128 or (c in spacelist) or (c==sep):
+ break
+ j += 1
+ if j < len(string) and string[j] == '>':
+ j += 1
+ res.append(string[i:j])
+ i = j
+ return res
+
+def stripoff_tags(x):
+ if not x: return ''
+ chars = []
+ i = 0; T=len(x)
+ while i < T:
+ if x[i] == '<':
+ while i < T and x[i] != '>':
+ i += 1
+ i += 1
+ else:
+ chars.append(x[i])
+ i += 1
+ return ''.join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+ """ sentence, ignore_words are both in unicode
+ """
+ new_sentence = []
+ for token in sentence:
+ x = token
+ if not cs:
+ x = x.upper()
+ if x in ignore_words:
+ continue
+ if remove_tag:
+ x = stripoff_tags(x)
+ if not x:
+ continue
+ if split and x in split:
+ new_sentence += split[x]
+ else:
+ new_sentence.append(x)
+ return new_sentence
+
+class Calculator :
+ def __init__(self) :
+ self.data = {}
+ self.space = []
+ self.cost = {}
+ self.cost['cor'] = 0
+ self.cost['sub'] = 1
+ self.cost['del'] = 1
+ self.cost['ins'] = 1
+ def calculate(self, lab, rec) :
+ # Initialization
+ lab.insert(0, '')
+ rec.insert(0, '')
+ while len(self.space) < len(lab) :
+ self.space.append([])
+ for row in self.space :
+ for element in row :
+ element['dist'] = 0
+ element['error'] = 'non'
+ while len(row) < len(rec) :
+ row.append({'dist' : 0, 'error' : 'non'})
+ for i in range(len(lab)) :
+ self.space[i][0]['dist'] = i
+ self.space[i][0]['error'] = 'del'
+ for j in range(len(rec)) :
+ self.space[0][j]['dist'] = j
+ self.space[0][j]['error'] = 'ins'
+ self.space[0][0]['error'] = 'non'
+ for token in lab :
+ if token not in self.data and len(token) > 0 :
+ self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
+ for token in rec :
+ if token not in self.data and len(token) > 0 :
+ self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
+ # Computing edit distance
+ for i, lab_token in enumerate(lab) :
+ for j, rec_token in enumerate(rec) :
+ if i == 0 or j == 0 :
+ continue
+ min_dist = sys.maxsize
+ min_error = 'none'
+ dist = self.space[i-1][j]['dist'] + self.cost['del']
+ error = 'del'
+ if dist < min_dist :
+ min_dist = dist
+ min_error = error
+ dist = self.space[i][j-1]['dist'] + self.cost['ins']
+ error = 'ins'
+ if dist < min_dist :
+ min_dist = dist
+ min_error = error
+ if lab_token == rec_token :
+ dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
+ error = 'cor'
+ else :
+ dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
+ error = 'sub'
+ if dist < min_dist :
+ min_dist = dist
+ min_error = error
+ self.space[i][j]['dist'] = min_dist
+ self.space[i][j]['error'] = min_error
+ # Tracing back
+ result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+ i = len(lab) - 1
+ j = len(rec) - 1
+ while True :
+ if self.space[i][j]['error'] == 'cor' : # correct
+ if len(lab[i]) > 0 :
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
+ result['all'] = result['all'] + 1
+ result['cor'] = result['cor'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, rec[j])
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]['error'] == 'sub' : # substitution
+ if len(lab[i]) > 0 :
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
+ result['all'] = result['all'] + 1
+ result['sub'] = result['sub'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, rec[j])
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]['error'] == 'del' : # deletion
+ if len(lab[i]) > 0 :
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
+ result['all'] = result['all'] + 1
+ result['del'] = result['del'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, "")
+ i = i - 1
+ elif self.space[i][j]['error'] == 'ins' : # insertion
+ if len(rec[j]) > 0 :
+ self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
+ result['ins'] = result['ins'] + 1
+ result['lab'].insert(0, "")
+ result['rec'].insert(0, rec[j])
+ j = j - 1
+ elif self.space[i][j]['error'] == 'non' : # starting point
+ break
+ else : # shouldn't reach here
+ print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
+ return result
+ def overall(self) :
+ result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+ for token in self.data :
+ result['all'] = result['all'] + self.data[token]['all']
+ result['cor'] = result['cor'] + self.data[token]['cor']
+ result['sub'] = result['sub'] + self.data[token]['sub']
+ result['ins'] = result['ins'] + self.data[token]['ins']
+ result['del'] = result['del'] + self.data[token]['del']
+ return result
+ def cluster(self, data) :
+ result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+ for token in data :
+ if token in self.data :
+ result['all'] = result['all'] + self.data[token]['all']
+ result['cor'] = result['cor'] + self.data[token]['cor']
+ result['sub'] = result['sub'] + self.data[token]['sub']
+ result['ins'] = result['ins'] + self.data[token]['ins']
+ result['del'] = result['del'] + self.data[token]['del']
+ return result
+ def keys(self) :
+ return list(self.data.keys())
+
+def width(string):
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+
+def default_cluster(word) :
+ unicode_names = [ unicodedata.name(char) for char in word ]
+ for i in reversed(range(len(unicode_names))) :
+ if unicode_names[i].startswith('DIGIT') : # 1
+ unicode_names[i] = 'Number' # 'DIGIT'
+ elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
+ unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
+ # 明 / 郎
+ unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
+ elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
+ unicode_names[i].startswith('LATIN SMALL LETTER')) :
+ # A / a
+ unicode_names[i] = 'English' # 'LATIN LETTER'
+ elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
+ unicode_names[i] = 'Japanese' # 'GANA LETTER'
+ elif (unicode_names[i].startswith('AMPERSAND') or
+ unicode_names[i].startswith('APOSTROPHE') or
+ unicode_names[i].startswith('COMMERCIAL AT') or
+ unicode_names[i].startswith('DEGREE CELSIUS') or
+ unicode_names[i].startswith('EQUALS SIGN') or
+ unicode_names[i].startswith('FULL STOP') or
+ unicode_names[i].startswith('HYPHEN-MINUS') or
+ unicode_names[i].startswith('LOW LINE') or
+ unicode_names[i].startswith('NUMBER SIGN') or
+ unicode_names[i].startswith('PLUS SIGN') or
+ unicode_names[i].startswith('SEMICOLON')) :
+ # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
+ del unicode_names[i]
+ else :
+ return 'Other'
+ if len(unicode_names) == 0 :
+ return 'Other'
+ if len(unicode_names) == 1 :
+ return unicode_names[0]
+ for i in range(len(unicode_names)-1) :
+ if unicode_names[i] != unicode_names[i+1] :
+ return 'Other'
+ return unicode_names[0]
+
+def usage() :
+ print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
+ print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
+
+if __name__ == '__main__':
+ if len(sys.argv) == 1 :
+ usage()
+ sys.exit(0)
+ calculator = Calculator()
+ cluster_file = ''
+ ignore_words = set()
+ tochar = False
+ verbose= 1
+ padding_symbol= ' '
+ case_sensitive = False
+ max_words_per_line = sys.maxsize
+ split = None
+ while len(sys.argv) > 3:
+ a = '--maxw='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):]
+ del sys.argv[1]
+ max_words_per_line = int(b)
+ continue
+ a = '--rt='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ remove_tag = (b == 'true') or (b != '0')
+ continue
+ a = '--cs='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ case_sensitive = (b == 'true') or (b != '0')
+ continue
+ a = '--cluster='
+ if sys.argv[1].startswith(a):
+ cluster_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ continue
+ a = '--splitfile='
+ if sys.argv[1].startswith(a):
+ split_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ split = dict()
+ with codecs.open(split_file, 'r', 'utf-8') as fh:
+ for line in fh: # line in unicode
+ words = line.strip().split()
+ if len(words) >= 2:
+ split[words[0]] = words[1:]
+ continue
+ a = '--ig='
+ if sys.argv[1].startswith(a):
+ ignore_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ with codecs.open(ignore_file, 'r', 'utf-8') as fh:
+ for line in fh: # line in unicode
+ line = line.strip()
+ if len(line) > 0:
+ ignore_words.add(line)
+ continue
+ a = '--char='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ tochar = (b == 'true') or (b != '0')
+ continue
+ a = '--v='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ verbose=0
+ try:
+ verbose=int(b)
+ except:
+ if b == 'true' or b != '0':
+ verbose = 1
+ continue
+ a = '--padding-symbol='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ if b == 'space':
+ padding_symbol= ' '
+ elif b == 'underline':
+ padding_symbol= '_'
+ continue
+ if True or sys.argv[1].startswith('-'):
+ #ignore invalid switch
+ del sys.argv[1]
+ continue
+
+ if not case_sensitive:
+ ig=set([w.upper() for w in ignore_words])
+ ignore_words = ig
+
+ default_clusters = {}
+ default_words = {}
+
+ ref_file = sys.argv[1]
+ hyp_file = sys.argv[2]
+ rec_set = {}
+ if split and not case_sensitive:
+ newsplit = dict()
+ for w in split:
+ words = split[w]
+ for i in range(len(words)):
+ words[i] = words[i].upper()
+ newsplit[w.upper()] = words
+ split = newsplit
+
+ with codecs.open(hyp_file, 'r', 'utf-8') as fh:
+ for line in fh:
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.strip().split()
+ if len(array)==0: continue
+ fid = array[0]
+ rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+ # compute error rate on the interaction of reference file and hyp file
+ for line in open(ref_file, 'r', encoding='utf-8') :
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.rstrip('\n').split()
+ if len(array)==0: continue
+ fid = array[0]
+ if fid not in rec_set:
+ continue
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
+ rec = rec_set[fid]
+ if verbose:
+ print('\nutt: %s' % fid)
+
+ for word in rec + lab :
+ if word not in default_words :
+ default_cluster_name = default_cluster(word)
+ if default_cluster_name not in default_clusters :
+ default_clusters[default_cluster_name] = {}
+ if word not in default_clusters[default_cluster_name] :
+ default_clusters[default_cluster_name][word] = 1
+ default_words[word] = default_cluster_name
+
+ result = calculator.calculate(lab, rec)
+ if verbose:
+ if result['all'] != 0 :
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else :
+ wer = 0.0
+ print('WER: %4.2f %%' % wer, end = ' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ space = {}
+ space['lab'] = []
+ space['rec'] = []
+ for idx in range(len(result['lab'])) :
+ len_lab = width(result['lab'][idx])
+ len_rec = width(result['rec'][idx])
+ length = max(len_lab, len_rec)
+ space['lab'].append(length-len_lab)
+ space['rec'].append(length-len_rec)
+ upper_lab = len(result['lab'])
+ upper_rec = len(result['rec'])
+ lab1, rec1 = 0, 0
+ while lab1 < upper_lab or rec1 < upper_rec:
+ if verbose > 1:
+ print('lab(%s):' % fid.encode('utf-8'), end = ' ')
+ else:
+ print('lab:', end = ' ')
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
+ for idx in range(lab1, lab2):
+ token = result['lab'][idx]
+ print('{token}'.format(token = token), end = '')
+ for n in range(space['lab'][idx]) :
+ print(padding_symbol, end = '')
+ print(' ',end='')
+ print()
+ if verbose > 1:
+ print('rec(%s):' % fid.encode('utf-8'), end = ' ')
+ else:
+ print('rec:', end = ' ')
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
+ for idx in range(rec1, rec2):
+ token = result['rec'][idx]
+ print('{token}'.format(token = token), end = '')
+ for n in range(space['rec'][idx]) :
+ print(padding_symbol, end = '')
+ print(' ',end='')
+ print('\n', end='\n')
+ lab1 = lab2
+ rec1 = rec2
+
+ if verbose:
+ print('===========================================================================')
+ print()
+
+ result = calculator.overall()
+ if result['all'] != 0 :
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else :
+ wer = 0.0
+ print('Overall -> %4.2f %%' % wer, end = ' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ if not verbose:
+ print()
+
+ if verbose:
+ for cluster_id in default_clusters :
+ result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
+ if result['all'] != 0 :
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else :
+ wer = 0.0
+ print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ if len(cluster_file) > 0 : # compute separated WERs for word clusters
+ cluster_id = ''
+ cluster = []
+ for line in open(cluster_file, 'r', encoding='utf-8') :
+ for token in line.rstrip('\n').split() :
+ # end of cluster reached, e.g. a tag like </cluster_id>
+ if token[0:2] == '</' and token[len(token)-1] == '>' and \
+ token.lstrip('</').rstrip('>') == cluster_id :
+ result = calculator.cluster(cluster)
+ if result['all'] != 0 :
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else :
+ wer = 0.0
+ print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ cluster_id = ''
+ cluster = []
+ # begin of cluster reached, e.g. a tag like <cluster_id>
+ elif token[0] == '<' and token[len(token)-1] == '>' and \
+ cluster_id == '' :
+ cluster_id = token.lstrip('<').rstrip('>')
+ cluster = []
+ # general terms, like WEATHER / CAR / ...
+ else :
+ cluster.append(token)
+ print()
+ print('===========================================================================')
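
For reference, a small worked example (hypothetical counts) of the WER formula printed by the script above:

all_, cor, sub, dele, ins = 10, 7, 2, 1, 1      # hypothetical counts, as from calculator.overall()
wer = float(ins + sub + dele) * 100.0 / all_    # (1 + 2 + 1) * 100 / 10 = 40.00 %
# insertions are counted in the numerator but not in N, so WER can exceed 100 %
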
diff --git a/speechx/examples/aishell/local/split_data.sh b/speechx/examples/aishell/local/split_data.sh
new file mode 100755
index 000000000..df454d6cf
--- /dev/null
+++ b/speechx/examples/aishell/local/split_data.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+data=$1
+feat_scp=$2
+split_feat_name=$3
+numsplit=$4
+
+
+if ! [ "$numsplit" -gt 0 ]; then
+ echo "Invalid num-split argument";
+ exit 1;
+fi
+
+directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
+feat_split_scp=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_feat_name}; done)
+echo $feat_split_scp
+# if this mkdir fails due to argument-list being too long, iterate.
+if ! mkdir -p $directories >&/dev/null; then
+ for n in `seq $numsplit`; do
+ mkdir -p $data/split${numsplit}/$n
+ done
+fi
+
+utils/split_scp.pl $feat_scp $feat_split_scp
diff --git a/speechx/examples/aishell/path.sh b/speechx/examples/aishell/path.sh
new file mode 100644
index 000000000..a0e7c9aed
--- /dev/null
+++ b/speechx/examples/aishell/path.sh
@@ -0,0 +1,14 @@
+# This file contains the locations of the binaries that must be built before running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project was built successfully"; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/aishell/run.sh b/speechx/examples/aishell/run.sh
new file mode 100755
index 000000000..8a16a8650
--- /dev/null
+++ b/speechx/examples/aishell/run.sh
@@ -0,0 +1,113 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzfv paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+mkdir -p data
+data=$PWD/data
+aishell_wav_scp=aishell_test.scp
+if [ ! -d $data/test ]; then
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
+ unzip -d $data aishell_test.zip
+ realpath $data/test/*/*.wav > $data/wavlist
+ awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
+ paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
+fi
+
+model_dir=$PWD/aishell_ds2_online_model
+if [ ! -d $model_dir ]; then
+ mkdir -p $model_dir
+ wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
+ tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir
+fi
+
+# 3. make feature
+aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
+lm_model_dir=../paddle_asr_model
+label_file=./aishell_result
+wer=./aishell_wer
+
+nj=40
+export GLOG_logtostderr=1
+
+#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+
+data=$PWD/data
+# 3. gen linear feat
+cmvn=$PWD/cmvn.ark
+cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn
+
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
+linear_spectrogram_without_db_norm_main \
+ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
+ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
+ --cmvn_file=$cmvn \
+ --streaming_chunk=0.36
+
+text=$data/test/text
+
+# 4. recognizer
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
+ offline_decoder_sliding_chunk_main \
+ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
+ --model_path=$aishell_online_model/avg_1.jit.pdmodel \
+ --param_path=$aishell_online_model/avg_1.jit.pdiparams \
+ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+ --dict_file=$lm_model_dir/vocab.txt \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/result
+
+cat $data/split${nj}/*/result > ${label_file}
+local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
+
+# 4. decode with lm
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
+ offline_decoder_sliding_chunk_main \
+ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
+ --model_path=$aishell_online_model/avg_1.jit.pdmodel \
+ --param_path=$aishell_online_model/avg_1.jit.pdiparams \
+ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+ --dict_file=$lm_model_dir/vocab.txt \
+ --lm_path=$lm_model_dir/avg_1.jit.klm \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
+
+cat $data/split${nj}/*/result_lm > ${label_file}_lm
+local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
+
+graph_dir=./aishell_graph
+if [ ! -d $graph_dir ]; then
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
+ unzip aishell_graph.zip
+fi
+
+# 5. test TLG decoder
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
+ offline_wfst_decoder_main \
+ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
+ --model_path=$aishell_online_model/avg_1.jit.pdmodel \
+ --param_path=$aishell_online_model/avg_1.jit.pdiparams \
+ --word_symbol_table=$graph_dir/words.txt \
+ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+ --graph_path=$graph_dir/TLG.fst --max_active=7500 \
+ --acoustic_scale=1.2 \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
+
+cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
+local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
\ No newline at end of file
diff --git a/speechx/examples/aishell/utils b/speechx/examples/aishell/utils
new file mode 120000
index 000000000..973afe674
--- /dev/null
+++ b/speechx/examples/aishell/utils
@@ -0,0 +1 @@
+../../../utils
\ No newline at end of file
diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt
index ded423e94..d446a6715 100644
--- a/speechx/examples/decoder/CMakeLists.txt
+++ b/speechx/examples/decoder/CMakeLists.txt
@@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
+add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
+target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
+
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
index 7f6c572ca..40092de31 100644
--- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
+++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
@@ -22,30 +22,37 @@
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
-DEFINE_string(feature_respecifier, "", "test feature rspecifier");
+DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
+DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
-DEFINE_string(lm_path, "lm.klm", "language model");
+DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(model_output_names,
+ "save_infer_model/scale_0.tmp_1,save_infer_model/"
+ "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
+ "scale_3.tmp_1",
+ "model output names");
+DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
-
// test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
- FLAGS_feature_respecifier);
+ FLAGS_feature_rspecifier);
+ kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
@@ -55,7 +62,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
-
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
@@ -66,7 +72,8 @@ int main(int argc, char* argv[]) {
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
- model_opts.cache_shape = "5-1-1024,5-1-1024";
+ model_opts.cache_shape = FLAGS_model_cache_names;
+ model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr raw_data(new ppspeech::DataCache());
@@ -129,9 +136,16 @@ int main(int argc, char* argv[]) {
}
std::string result;
result = decoder.GetFinalBestPath();
- KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
+ if (result.empty()) {
+ // the TokenWriter can not write empty string.
+ ++num_err;
+ KALDI_LOG << " the result of " << utt << " is empty";
+ continue;
+ }
+ KALDI_LOG << " the result of " << utt << " is " << result;
+ result_writer.Write(utt, result);
++num_done;
}
diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc
new file mode 100644
index 000000000..06460a45e
--- /dev/null
+++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc
@@ -0,0 +1,158 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "decoder/ctc_tlg_decoder.h"
+#include "frontend/audio/data_cache.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/decodable.h"
+#include "nnet/paddle_nnet.h"
+
+DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
+DEFINE_string(result_wspecifier, "", "test result wspecifier");
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
+DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
+DEFINE_string(graph_path, "TLG", "decoder graph");
+DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
+DEFINE_int32(max_active, 7500, "decoder graph");
+DEFINE_int32(receptive_field_length,
+ 7,
+ "receptive field of two CNN(kernel=5) downsampling module.");
+DEFINE_int32(downsampling_rate,
+ 4,
+ "two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(model_output_names,
+ "save_infer_model/scale_0.tmp_1,save_infer_model/"
+ "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
+ "scale_3.tmp_1",
+ "model output names");
+DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
+
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+using std::vector;
+
+// test TLG decoder by feeding speech feature.
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ kaldi::SequentialBaseFloatMatrixReader feature_reader(
+ FLAGS_feature_rspecifier);
+ kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
+ std::string model_graph = FLAGS_model_path;
+ std::string model_params = FLAGS_param_path;
+ std::string word_symbol_table = FLAGS_word_symbol_table;
+ std::string graph_path = FLAGS_graph_path;
+ LOG(INFO) << "model path: " << model_graph;
+ LOG(INFO) << "model param: " << model_params;
+ LOG(INFO) << "word symbol path: " << word_symbol_table;
+ LOG(INFO) << "graph path: " << graph_path;
+
+ int32 num_done = 0, num_err = 0;
+
+ ppspeech::TLGDecoderOptions opts;
+ opts.word_symbol_table = word_symbol_table;
+ opts.fst_path = graph_path;
+ opts.opts.max_active = FLAGS_max_active;
+ opts.opts.beam = 15.0;
+ opts.opts.lattice_beam = 7.5;
+ ppspeech::TLGDecoder decoder(opts);
+
+ ppspeech::ModelOptions model_opts;
+ model_opts.model_path = model_graph;
+ model_opts.params_path = model_params;
+ model_opts.cache_shape = FLAGS_model_cache_names;
+ model_opts.output_names = FLAGS_model_output_names;
+ std::shared_ptr<ppspeech::PaddleNnet> nnet(
+ new ppspeech::PaddleNnet(model_opts));
+ std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
+ std::shared_ptr<ppspeech::Decodable> decodable(
+ new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
+
+ int32 chunk_size = FLAGS_receptive_field_length;
+ int32 chunk_stride = FLAGS_downsampling_rate;
+ int32 receptive_field_length = FLAGS_receptive_field_length;
+ LOG(INFO) << "chunk size (frame): " << chunk_size;
+ LOG(INFO) << "chunk stride (frame): " << chunk_stride;
+ LOG(INFO) << "receptive field (frame): " << receptive_field_length;
+ decoder.InitDecoder();
+
+ for (; !feature_reader.Done(); feature_reader.Next()) {
+ string utt = feature_reader.Key();
+ kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
+ raw_data->SetDim(feature.NumCols());
+ LOG(INFO) << "process utt: " << utt;
+ LOG(INFO) << "rows: " << feature.NumRows();
+ LOG(INFO) << "cols: " << feature.NumCols();
+
+ int32 row_idx = 0;
+ int32 padding_len = 0;
+ int32 ori_feature_len = feature.NumRows();
+ if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
+ padding_len =
+ chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
+ feature.Resize(feature.NumRows() + padding_len,
+ feature.NumCols(),
+ kaldi::kCopyData);
+ }
+ int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
+ for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
+ kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
+ feature.NumCols());
+ int32 feature_chunk_size = 0;
+ if (ori_feature_len > chunk_idx * chunk_stride) {
+ feature_chunk_size = std::min(
+ ori_feature_len - chunk_idx * chunk_stride, chunk_size);
+ }
+ if (feature_chunk_size < receptive_field_length) break;
+
+ int32 start = chunk_idx * chunk_stride;
+ for (int row_id = 0; row_id < chunk_size; ++row_id) {
+ kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
+ kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
+ feature_chunk.Data() + row_id * feature.NumCols(),
+ feature.NumCols());
+ f_chunk_tmp.CopyFromVec(tmp);
+ ++start;
+ }
+ raw_data->Accept(feature_chunk);
+ if (chunk_idx == num_chunks - 1) {
+ raw_data->SetFinished();
+ }
+ decoder.AdvanceDecode(decodable);
+ }
+ std::string result;
+ result = decoder.GetFinalBestPath();
+ decodable->Reset();
+ decoder.Reset();
+ if (result.empty()) {
+ // the TokenWriter can not write empty string.
+ ++num_err;
+ KALDI_LOG << " the result of " << utt << " is empty";
+ continue;
+ }
+ KALDI_LOG << " the result of " << utt << " is " << result;
+ result_writer.Write(utt, result);
+ ++num_done;
+ }
+
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
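
The chunk and padding arithmetic in the loop above can be sanity-checked with a small Python sketch (hypothetical numbers, equivalent to the C++ logic):

rows, chunk_size, chunk_stride = 100, 7, 4        # feature rows, receptive field, downsampling rate
pad = 0
if (rows - chunk_size) % chunk_stride != 0:
    pad = chunk_stride - (rows - chunk_size) % chunk_stride   # 4 - 93 % 4 = 3 padded rows
num_chunks = (rows + pad - chunk_size) // chunk_stride + 1    # (103 - 7) // 4 + 1 = 25 chunks
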
diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt
index b8f516afb..d6fdb9bc6 100644
--- a/speechx/examples/feat/CMakeLists.txt
+++ b/speechx/examples/feat/CMakeLists.txt
@@ -7,4 +7,12 @@ target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
\ No newline at end of file
+target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
+
+add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
+target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
+
+add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
+target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)
diff --git a/speechx/examples/feat/cmvn_json2binary_main.cc b/speechx/examples/feat/cmvn_json2binary_main.cc
new file mode 100644
index 000000000..e77f983aa
--- /dev/null
+++ b/speechx/examples/feat/cmvn_json2binary_main.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "kaldi/matrix/kaldi-matrix.h"
+#include "kaldi/util/kaldi-io.h"
+#include "utils/file_utils.h"
+#include "utils/simdjson.h"
+
+DEFINE_string(json_file, "", "cmvn json file");
+DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
+DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
+
+using namespace simdjson;
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ ondemand::parser parser;
+ padded_string json = padded_string::load(FLAGS_json_file);
+ ondemand::document val = parser.iterate(json);
+ ondemand::object doc = val;
+ kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
+ auto mean_stat = doc["mean_stat"];
+ std::vector mean_stat_vec;
+ for (double x : mean_stat) {
+ mean_stat_vec.push_back(x);
+ }
+ auto var_stat = doc["var_stat"];
+ std::vector var_stat_vec;
+ for (double x : var_stat) {
+ var_stat_vec.push_back(x);
+ }
+
+ size_t mean_size = mean_stat_vec.size();
+ kaldi::Matrix cmvn_stats(2, mean_size + 1);
+ for (size_t idx = 0; idx < mean_size; ++idx) {
+ cmvn_stats(0, idx) = mean_stat_vec[idx];
+ cmvn_stats(1, idx) = var_stat_vec[idx];
+ }
+ cmvn_stats(0, mean_size) = frame_num;
+ kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
+ LOG(INFO) << "the cmvn stats have been written to " << FLAGS_cmvn_write_path;
+ return 0;
+}
\ No newline at end of file
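
A minimal Python sketch (illustration only, assuming mean_stat and var_stat hold accumulated first- and second-order sums as in Kaldi CMVN stats) of the matrix layout this tool writes:

import json
import numpy as np

with open("mean_std.json") as f:        # keys: frame_num, mean_stat, var_stat
    stats = json.load(f)
dim = len(stats["mean_stat"])
cmvn = np.zeros((2, dim + 1))
cmvn[0, :dim] = stats["mean_stat"]      # row 0: accumulated sums
cmvn[1, :dim] = stats["var_stat"]       # row 1: accumulated squared sums
cmvn[0, dim] = stats["frame_num"]       # frame count goes in the last column of row 0
# kaldi::WriteKaldiObject then serializes this 2 x (dim + 1) matrix to cmvn.ark
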
diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc
index 2d75bb5df..2e70386d6 100644
--- a/speechx/examples/feat/linear_spectrogram_main.cc
+++ b/speechx/examples/feat/linear_spectrogram_main.cc
@@ -30,6 +30,7 @@
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
+DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
std::vector mean_{
@@ -181,6 +182,7 @@ int main(int argc, char* argv[]) {
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
+ opt.streaming_chunk = FLAGS_streaming_chunk;
opt.frame_opts.dither = 0.0;
opt.frame_opts.remove_dc_offset = false;
opt.frame_opts.window_type = "hanning";
@@ -198,7 +200,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
- float streaming_chunk = 0.36;
+ float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
@@ -256,6 +258,7 @@ int main(int argc, char* argv[]) {
}
}
feat_writer.Write(utt, features);
+ feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
diff --git a/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc b/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc
new file mode 100644
index 000000000..5b875a3ee
--- /dev/null
+++ b/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "kaldi/feat/wave-reader.h"
+#include "kaldi/util/kaldi-io.h"
+#include "kaldi/util/table-types.h"
+
+#include "frontend/audio/audio_cache.h"
+#include "frontend/audio/data_cache.h"
+#include "frontend/audio/feature_cache.h"
+#include "frontend/audio/frontend_itf.h"
+#include "frontend/audio/linear_spectrogram.h"
+#include "frontend/audio/normalizer.h"
+
+DEFINE_string(wav_rspecifier, "", "test wav scp path");
+DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
+DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
+DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+ FLAGS_wav_rspecifier);
+ kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
+
+ int32 num_done = 0, num_err = 0;
+
+ // feature pipeline: wave cache --> hanning
+ // window -->linear_spectrogram --> global cmvn -> feat cache
+
+ std::unique_ptr<ppspeech::FrontendInterface> data_source(
+ new ppspeech::AudioCache(3600 * 1600, true));
+
+ ppspeech::LinearSpectrogramOptions opt;
+ opt.frame_opts.frame_length_ms = 20;
+ opt.frame_opts.frame_shift_ms = 10;
+ opt.streaming_chunk = FLAGS_streaming_chunk;
+ opt.frame_opts.dither = 0.0;
+ opt.frame_opts.remove_dc_offset = false;
+ opt.frame_opts.window_type = "hanning";
+ opt.frame_opts.preemph_coeff = 0.0;
+ LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
+ LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
+
+ std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
+ new ppspeech::LinearSpectrogram(opt, std::move(data_source)));
+
+ std::unique_ptr<ppspeech::FrontendInterface> cmvn(
+ new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
+
+ ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
+ LOG(INFO) << "feat dim: " << feature_cache.Dim();
+
+ int sample_rate = 16000;
+ float streaming_chunk = FLAGS_streaming_chunk;
+ int chunk_sample_size = streaming_chunk * sample_rate;
+ LOG(INFO) << "sr: " << sample_rate;
+ LOG(INFO) << "chunk size (s): " << streaming_chunk;
+ LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
+
+
+ for (; !wav_reader.Done(); wav_reader.Next()) {
+ std::string utt = wav_reader.Key();
+ const kaldi::WaveData& wave_data = wav_reader.Value();
+ LOG(INFO) << "process utt: " << utt;
+
+ int32 this_channel = 0;
+ kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+ this_channel);
+ int tot_samples = waveform.Dim();
+ LOG(INFO) << "wav len (sample): " << tot_samples;
+
+ int sample_offset = 0;
+ std::vector<kaldi::Vector<kaldi::BaseFloat>> feats;
+ int feature_rows = 0;
+ while (sample_offset < tot_samples) {
+ int cur_chunk_size =
+ std::min(chunk_sample_size, tot_samples - sample_offset);
+
+ kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+ for (int i = 0; i < cur_chunk_size; ++i) {
+ wav_chunk(i) = waveform(sample_offset + i);
+ }
+
+ kaldi::Vector<kaldi::BaseFloat> features;
+ feature_cache.Accept(wav_chunk);
+ if (cur_chunk_size < chunk_sample_size) {
+ feature_cache.SetFinished();
+ }
+ feature_cache.Read(&features);
+ if (features.Dim() == 0) break;
+
+ feats.push_back(features);
+ sample_offset += cur_chunk_size;
+ feature_rows += features.Dim() / feature_cache.Dim();
+ }
+
+ int cur_idx = 0;
+ kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
+ feature_cache.Dim());
+ for (auto feat : feats) {
+ int num_rows = feat.Dim() / feature_cache.Dim();
+ for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
+ for (size_t col_idx = 0; col_idx < feature_cache.Dim();
+ ++col_idx) {
+ features(cur_idx, col_idx) =
+ feat(row_idx * feature_cache.Dim() + col_idx);
+ }
+ ++cur_idx;
+ }
+ }
+ feat_writer.Write(utt, features);
+ feature_cache.Reset();
+
+ if (num_done % 50 == 0 && num_done != 0)
+ KALDI_VLOG(2) << "Processed " << num_done << " utterances";
+ num_done++;
+ }
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
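
As a quick sanity check of the default flag value used above:

sample_rate = 16000
streaming_chunk = 0.36                                   # FLAGS_streaming_chunk default, in seconds
chunk_sample_size = int(streaming_chunk * sample_rate)   # 5760 samples fed per streaming step
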
diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt
index 7cd281b66..ee0863fd5 100644
--- a/speechx/speechx/decoder/CMakeLists.txt
+++ b/speechx/speechx/decoder/CMakeLists.txt
@@ -6,5 +6,6 @@ add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
+ ctc_tlg_decoder.cc
)
-target_link_libraries(decoder PUBLIC kenlm utils fst)
\ No newline at end of file
+target_link_libraries(decoder PUBLIC kenlm utils fst)
diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc
index 5d7a4f77a..b4caa8e7b 100644
--- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc
+++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc
@@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode(
vector> likelihood;
vector frame_prob;
bool flag =
- decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
+ decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h
index 1387eee79..9d0a5d142 100644
--- a/speechx/speechx/decoder/ctc_beam_search_decoder.h
+++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h
@@ -15,7 +15,7 @@
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
-#include "nnet/decodable-itf.h"
+#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"
#pragma once
diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc
new file mode 100644
index 000000000..5365e7090
--- /dev/null
+++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/ctc_tlg_decoder.h"
+namespace ppspeech {
+
+TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
+ fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
+ CHECK(fst_ != nullptr);
+ word_symbol_table_.reset(
+ fst::SymbolTable::ReadText(opts.word_symbol_table));
+ decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
+ decoder_->InitDecoding();
+ frame_decoded_size_ = 0;
+}
+
+void TLGDecoder::InitDecoder() {
+ decoder_->InitDecoding();
+ frame_decoded_size_ = 0;
+}
+
+void TLGDecoder::AdvanceDecode(
+ const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
+ while (!decodable->IsLastFrame(frame_decoded_size_)) {
+ LOG(INFO) << "num frame decode: " << frame_decoded_size_;
+ AdvanceDecoding(decodable.get());
+ }
+}
+
+void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
+ decoder_->AdvanceDecoding(decodable, 1);
+ frame_decoded_size_++;
+}
+
+void TLGDecoder::Reset() {
+ InitDecoder();
+ return;
+}
+
+std::string TLGDecoder::GetFinalBestPath() {
+ decoder_->FinalizeDecoding();
+ kaldi::Lattice lat;
+ kaldi::LatticeWeight weight;
+ std::vector<int32> alignment;
+ std::vector<int32> words_id;
+ decoder_->GetBestPath(&lat, true);
+ fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
+ std::string words;
+ for (int32 idx = 0; idx < words_id.size(); ++idx) {
+ std::string word = word_symbol_table_->Find(words_id[idx]);
+ words += word;
+ }
+ return words;
+}
+}
\ No newline at end of file
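
Illustration (hypothetical symbol-table entries) of the word-id to text mapping done in GetFinalBestPath above; the words are concatenated without separators, which suits Chinese output:

symbol_table = {1: "今天", 2: "天气", 3: "很好"}      # hypothetical words.txt entries
words_id = [1, 2, 3]                                  # word ids on the best path
text = "".join(symbol_table[i] for i in words_id)     # -> "今天天气很好"
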
diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h
new file mode 100644
index 000000000..361c44af5
--- /dev/null
+++ b/speechx/speechx/decoder/ctc_tlg_decoder.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "base/basic_types.h"
+#include "kaldi/decoder/decodable-itf.h"
+#include "kaldi/decoder/lattice-faster-online-decoder.h"
+#include "util/parse-options.h"
+
+namespace ppspeech {
+
+struct TLGDecoderOptions {
+ kaldi::LatticeFasterDecoderConfig opts;
+ // todo remove later, add into decode resource
+ std::string word_symbol_table;
+ std::string fst_path;
+
+ TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
+};
+
+class TLGDecoder {
+ public:
+ explicit TLGDecoder(TLGDecoderOptions opts);
+ void InitDecoder();
+ void Decode();
+ std::string GetBestPath();
+ std::vector> GetNBestPath();
+ std::string GetFinalBestPath();
+ int NumFrameDecoded();
+ int DecodeLikelihoods(const std::vector>& probs,
+ std::vector& nbest_words);
+ void AdvanceDecode(
+ const std::shared_ptr<kaldi::DecodableInterface>& decodable);
+ void Reset();
+
+ private:
+ void AdvanceDecoding(kaldi::DecodableInterface* decodable);
+
+ std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
+ std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
+ std::shared_ptr<fst::SymbolTable> word_symbol_table_;
+ // the number of frames that have been decoded; starts from 0.
+ int32 frame_decoded_size_;
+};
+
+
+} // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc
index c3233e595..50aca4fb0 100644
--- a/speechx/speechx/frontend/audio/audio_cache.cc
+++ b/speechx/speechx/frontend/audio/audio_cache.cc
@@ -21,15 +21,20 @@ using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector;
-AudioCache::AudioCache(int buffer_size)
+AudioCache::AudioCache(int buffer_size, bool convert2PCM32)
: finished_(false),
capacity_(buffer_size),
size_(0),
offset_(0),
- timeout_(1) {
+ timeout_(1),
+ convert2PCM32_(convert2PCM32) {
ring_buffer_.resize(capacity_);
}
+BaseFloat AudioCache::Convert2PCM32(BaseFloat val) {
+ return val * (1. / std::pow(2.0, 15));
+}
+
void AudioCache::Accept(const VectorBase& waves) {
std::unique_lock lock(mutex_);
while (size_ + waves.Dim() > ring_buffer_.size()) {
@@ -38,6 +43,8 @@ void AudioCache::Accept(const VectorBase& waves) {
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
+ if (convert2PCM32_)
+ ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
}
size_ += waves.Dim();
}
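
Numerically, Convert2PCM32 above rescales int16-range samples into the [-1, 1) float range:

sample_int16 = 16384
sample_float = sample_int16 * (1.0 / 2 ** 15)   # 16384 / 32768 = 0.5
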
diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h
index 17e1a8389..adef12399 100644
--- a/speechx/speechx/frontend/audio/audio_cache.h
+++ b/speechx/speechx/frontend/audio/audio_cache.h
@@ -23,7 +23,8 @@ namespace ppspeech {
// waves cache
class AudioCache : public FrontendInterface {
public:
- explicit AudioCache(int buffer_size = kint16max);
+ explicit AudioCache(int buffer_size = 1000 * kint16max,
+ bool convert2PCM32 = false);
virtual void Accept(const kaldi::VectorBase& waves);
@@ -46,14 +47,17 @@ class AudioCache : public FrontendInterface {
}
private:
+ kaldi::BaseFloat Convert2PCM32(kaldi::BaseFloat val);
+
std::vector ring_buffer_;
size_t offset_; // offset in ring_buffer_
size_t size_; // samples in ring_buffer_ now
size_t capacity_; // capacity of ring_buffer_
bool finished_; // reach audio end
- mutable std::mutex mutex_;
+ std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_; // millisecond
+ bool convert2PCM32_;
DISALLOW_COPY_AND_ASSIGN(AudioCache);
};
diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/frontend/audio/cmvn.cc
index 4c1ffd6a1..c7e446c92 100644
--- a/speechx/speechx/frontend/audio/cmvn.cc
+++ b/speechx/speechx/frontend/audio/cmvn.cc
@@ -120,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase* feats) {
ApplyCmvn(stats_, var_norm_, feats);
}
-} // namespace ppspeech
\ No newline at end of file
+} // namespace ppspeech
diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h
index 896c494dd..6b20b8b94 100644
--- a/speechx/speechx/frontend/audio/linear_spectrogram.h
+++ b/speechx/speechx/frontend/audio/linear_spectrogram.h
@@ -46,7 +46,10 @@ class LinearSpectrogram : public FrontendInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
- virtual void Reset() { base_extractor_->Reset(); }
+ virtual void Reset() {
+ base_extractor_->Reset();
+ reminded_wav_.Resize(0);
+ }
private:
bool Compute(const kaldi::Vector& waves,
diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt
index 414a6fa0c..6f7398cd1 100644
--- a/speechx/speechx/kaldi/CMakeLists.txt
+++ b/speechx/speechx/kaldi/CMakeLists.txt
@@ -4,3 +4,6 @@ add_subdirectory(base)
add_subdirectory(util)
add_subdirectory(feat)
add_subdirectory(matrix)
+add_subdirectory(lat)
+add_subdirectory(fstext)
+add_subdirectory(decoder)
diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/speechx/speechx/kaldi/decoder/CMakeLists.txt
new file mode 100644
index 000000000..f1ee6eabb
--- /dev/null
+++ b/speechx/speechx/kaldi/decoder/CMakeLists.txt
@@ -0,0 +1,6 @@
+
+add_library(kaldi-decoder
+lattice-faster-decoder.cc
+lattice-faster-online-decoder.cc
+)
+target_link_libraries(kaldi-decoder PUBLIC kaldi-lat)
diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h
similarity index 98%
rename from speechx/speechx/nnet/decodable-itf.h
rename to speechx/speechx/kaldi/decoder/decodable-itf.h
index 8e9a5a72a..b8ce9143e 100644
--- a/speechx/speechx/nnet/decodable-itf.h
+++ b/speechx/speechx/kaldi/decoder/decodable-itf.h
@@ -121,7 +121,7 @@ class DecodableInterface {
/// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// features to be flushed out for compatibility with the baseline setup.
- virtual bool IsLastFrame(int32 frame) const = 0;
+ virtual bool IsLastFrame(int32 frame) = 0;
/// The call NumFramesReady() will return the number of frames currently
/// available
@@ -143,7 +143,7 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
- virtual bool FrameLogLikelihood(
+ virtual bool FrameLikelihood(
int32 frame, std::vector* likelihood) = 0;
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
index 42d1d2af4..ae6b71600 100644
--- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
@@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl, decoder::StdToken>
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
-template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
-template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
-template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
-template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;
} // end namespace kaldi.
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
index 2016ad571..d142a8c7d 100644
--- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
@@ -23,11 +23,10 @@
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
-#include "decoder/grammar-fst.h"
#include "fst/fstlib.h"
#include "fst/memory.h"
#include "fstext/fstext-lib.h"
-#include "itf/decodable-itf.h"
+#include "decoder/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc
index ebdace7e8..b5261503c 100644
--- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc
@@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned(
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
-template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
-template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
+//template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
+//template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi.
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h
index 8b10996fd..f57368a44 100644
--- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h
@@ -30,7 +30,7 @@
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
-#include "itf/decodable-itf.h"
+#include "decoder/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt
new file mode 100644
index 000000000..af91fd985
--- /dev/null
+++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt
@@ -0,0 +1,5 @@
+
+add_library(kaldi-fstext
+kaldi-fst-io.cc
+)
+target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h
new file mode 100644
index 000000000..0bfbc8f41
--- /dev/null
+++ b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h
@@ -0,0 +1,1357 @@
+// fstext/determinize-lattice-inl.h
+
+// Copyright 2009-2012 Microsoft Corporation
+// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
+#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
+// Do not include this file directly. It is included by determinize-lattice.h
+
+#include <algorithm>
+#include <climits>
+#include <deque>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace fst {
+
+// This class maps back and forth from/to integer id's to sequences of strings.
+// used in determinization algorithm. It is constructed in such a way that
+// finding the string-id of the successor of (string, next-label) has constant
+// time.
+
+// Note: class IntType, typically int32, is the type of the element in the
+// string (typically a template argument of the CompactLatticeWeightTpl).
+
+template <class IntType>
+class LatticeStringRepository {
+ public:
+ struct Entry {
+ const Entry *parent; // NULL for empty string.
+ IntType i;
+ inline bool operator==(const Entry &other) const {
+ return (parent == other.parent && i == other.i);
+ }
+ Entry() {}
+ Entry(const Entry &e) : parent(e.parent), i(e.i) {}
+ };
+ // Note: all Entry* pointers returned in function calls are
+ // owned by the repository itself, not by the caller!
+
+ // Interface guarantees empty string is NULL.
+ inline const Entry *EmptyString() { return NULL; }
+
+ // Returns string of "parent" with i appended. Pointer
+ // owned by repository
+ const Entry *Successor(const Entry *parent, IntType i) {
+ new_entry_->parent = parent;
+ new_entry_->i = i;
+
+ std::pair<typename SetType::iterator, bool> pr = set_.insert(new_entry_);
+ if (pr.second) { // Was successfully inserted (was not there). We need to
+ // replace the element we inserted, which resides on the
+ // stack, with one from the heap.
+ const Entry *ans = new_entry_;
+ new_entry_ = new Entry();
+ return ans;
+ } else { // Was not inserted because an equivalent Entry already
+ // existed.
+ return *pr.first;
+ }
+ }
+
+ const Entry *Concatenate(const Entry *a, const Entry *b) {
+ if (a == NULL)
+ return b;
+ else if (b == NULL)
+ return a;
+ std::vector<IntType> v;
+ ConvertToVector(b, &v);
+ const Entry *ans = a;
+ for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]);
+ return ans;
+ }
+ const Entry *CommonPrefix(const Entry *a, const Entry *b) {
+ std::vector<IntType> a_vec, b_vec;
+ ConvertToVector(a, &a_vec);
+ ConvertToVector(b, &b_vec);
+ const Entry *ans = NULL;
+ for (size_t i = 0;
+ i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++)
+ ans = Successor(ans, a_vec[i]);
+ return ans;
+ }
+
+ // removes any elements from b that are not part of
+ // a common prefix with a.
+ void ReduceToCommonPrefix(const Entry *a, std::vector<IntType> *b) {
+ size_t a_size = Size(a), b_size = b->size();
+ while (a_size > b_size) {
+ a = a->parent;
+ a_size--;
+ }
+ if (b_size > a_size) b_size = a_size;
+ typename std::vector<IntType>::iterator b_begin = b->begin();
+ while (a_size != 0) {
+ if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1;
+ a = a->parent;
+ a_size--;
+ }
+ if (b_size != b->size()) b->resize(b_size);
+ }
+
+ // removes the first n elements of a.
+ const Entry *RemovePrefix(const Entry *a, size_t n) {
+ if (n == 0) return a;
+ std::vector<IntType> a_vec;
+ ConvertToVector(a, &a_vec);
+ assert(a_vec.size() >= n);
+ const Entry *ans = NULL;
+ for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]);
+ return ans;
+ }
+
+ // Returns true if a is a prefix of b. If a is prefix of b,
+ // time taken is |b| - |a|. Else, time taken is |b|.
+ bool IsPrefixOf(const Entry *a, const Entry *b) const {
+ if (a == NULL) return true; // empty string prefix of all.
+ if (a == b) return true;
+ if (b == NULL) return false;
+ return IsPrefixOf(a, b->parent);
+ }
+
+ inline size_t Size(const Entry *entry) const {
+ size_t ans = 0;
+ while (entry != NULL) {
+ ans++;
+ entry = entry->parent;
+ }
+ return ans;
+ }
+
+ void ConvertToVector(const Entry *entry, std::vector<IntType> *out) const {
+ size_t length = Size(entry);
+ out->resize(length);
+ if (entry != NULL) {
+ typename std::vector<IntType>::reverse_iterator iter = out->rbegin();
+ while (entry != NULL) {
+ *iter = entry->i;
+ entry = entry->parent;
+ ++iter;
+ }
+ }
+ }
+
+ const Entry *ConvertFromVector(const std::vector<IntType> &vec) {
+ const Entry *e = NULL;
+ for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]);
+ return e;
+ }
+
+ LatticeStringRepository() { new_entry_ = new Entry; }
+
+ void Destroy() {
+ for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
+ ++iter)
+ delete *iter;
+ SetType tmp;
+ tmp.swap(set_);
+ if (new_entry_) {
+ delete new_entry_;
+ new_entry_ = NULL;
+ }
+ }
+
+ // Rebuild will rebuild this object, guaranteeing only
+ // to preserve the Entry values that are in the vector pointed
+ // to (this list does not have to be unique). The point of
+ // this is to save memory.
+ void Rebuild(const std::vector<const Entry *> &to_keep) {
+ SetType tmp_set;
+ for (typename std::vector<const Entry *>::const_iterator iter =
+ to_keep.begin();
+ iter != to_keep.end(); ++iter)
+ RebuildHelper(*iter, &tmp_set);
+ // Now delete all elems not in tmp_set.
+ for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
+ ++iter) {
+ if (tmp_set.count(*iter) == 0)
+ delete (*iter); // delete the Entry; not needed.
+ }
+ set_.swap(tmp_set);
+ }
+
+ ~LatticeStringRepository() { Destroy(); }
+ int32 MemSize() const {
+ return set_.size() * sizeof(Entry) * 2; // this is a lower bound
+ // on the size this structure might take.
+ }
+
+ private:
+ class EntryKey { // Hash function object.
+ public:
+ inline size_t operator()(const Entry *entry) const {
+ size_t prime = 49109;
+ return static_cast<size_t>(entry->i) +
+ prime * reinterpret_cast<size_t>(entry->parent);
+ }
+ };
+ class EntryEqual {
+ public:
+ inline bool operator()(const Entry *e1, const Entry *e2) const {
+ return (*e1 == *e2);
+ }
+ };
+ typedef std::unordered_set<const Entry *, EntryKey, EntryEqual> SetType;
+
+ void RebuildHelper(const Entry *to_add, SetType *tmp_set) {
+ while (true) {
+ if (to_add == NULL) return;
+ typename SetType::iterator iter = tmp_set->find(to_add);
+ if (iter == tmp_set->end()) { // not in tmp_set.
+ tmp_set->insert(to_add);
+ to_add = to_add->parent; // and loop.
+ } else {
+ return;
+ }
+ }
+ }
+
+ KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository);
+ Entry *new_entry_; // We always have a pre-allocated Entry ready to use,
+ // to avoid unnecessary news and deletes.
+ SetType set_;
+};
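+
+// Illustrative sketch (hypothetical usage, shown only as a comment): building
+// up a label sequence and reading it back with the repository.
+//
+//   LatticeStringRepository<int32> repo;
+//   const LatticeStringRepository<int32>::Entry *s = repo.EmptyString();
+//   s = repo.Successor(s, 7);           // string is now [7]
+//   s = repo.Successor(s, 3);           // string is now [7, 3]
+//   std::vector<int32> labels;
+//   repo.ConvertToVector(s, &labels);   // labels == {7, 3}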
+
+// class LatticeDeterminizer is templated on the same types that
+// CompactLatticeWeight is templated on: the base weight (Weight), typically
+// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the
+// IntType, typically int32, used for the output symbols in the compact
+// representation of strings [note: the output symbols would usually be
+// p.d.f. id's in the anticipated use of this code] It has a special requirement
+// on the Weight type: that there should be a Compare function on the weights
+// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1
+// > w2. This requires that there be a total order on the weights.
+
+template <class Weight, class IntType>
+class LatticeDeterminizer {
+ public:
+ // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
+ // correspondence between our states and the states in ofst. If destroy ==
+ // true, release memory as we go (but we cannot output again).
+
+ typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
+ typedef ArcTpl<CompactWeight>
+ CompactArc; // arc in compact, acceptor form of lattice
+ typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
+
+ // Output to standard FST with CompactWeightTpl as its weight type
+ // (the weight stores the original output-symbol strings). If destroy ==
+ // true, release memory as we go (but we cannot output again).
+ void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
+ assert(determinized_);
+ typedef typename Arc::StateId StateId;
+ StateId nStates = static_cast<StateId>(output_arcs_.size());
+ if (destroy) FreeMostMemory();
+ ofst->DeleteStates();
+ ofst->SetStart(kNoStateId);
+ if (nStates == 0) {
+ return;
+ }
+ for (StateId s = 0; s < nStates; s++) {
+ OutputStateId news = ofst->AddState();
+ assert(news == s);
+ }
+ ofst->SetStart(0);
+ // now process transitions.
+ for (StateId this_state = 0; this_state < nStates; this_state++) {
+ std::vector<TempArc> &this_vec(output_arcs_[this_state]);
+ typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
+ end = this_vec.end();
+
+ for (; iter != end; ++iter) {
+ const TempArc &temp_arc(*iter);
+ CompactArc new_arc;
+ std::vector<IntType> seq;
+ repository_.ConvertToVector(temp_arc.string, &seq);
+ CompactWeight weight(temp_arc.weight, seq);
+ if (temp_arc.nextstate == kNoStateId) { // is really final weight.
+ ofst->SetFinal(this_state, weight);
+ } else { // is really an arc.
+ new_arc.nextstate = temp_arc.nextstate;
+ new_arc.ilabel = temp_arc.ilabel;
+ new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
+ new_arc.weight = weight; // includes string and weight.
+ ofst->AddArc(this_state, new_arc);
+ }
+ }
+ // Free up memory. Do this inside the loop as ofst is also allocating
+ // memory
+ if (destroy) {
+ std::vector<TempArc> temp;
+ std::swap(temp, this_vec);
+ }
+ }
+ if (destroy) {
+ std::vector<std::vector<TempArc> > temp;
+ std::swap(temp, output_arcs_);
+ }
+ }
+
+ // Output to standard FST with Weight as its weight type. We will create
+ // extra states to handle sequences of symbols on the output. If destroy ==
+ // true, release memory as we go (but we cannot output again).
+ void Output(MutableFst<Arc> *ofst, bool destroy = true) {
+ // Outputs to standard fst.
+ OutputStateId nStates = static_cast<OutputStateId>(output_arcs_.size());
+ ofst->DeleteStates();
+ if (nStates == 0) {
+ ofst->SetStart(kNoStateId);
+ return;
+ }
+ if (destroy) FreeMostMemory();
+ // Add basic states-- but we will add extra ones to account for strings on
+ // output.
+ for (OutputStateId s = 0; s < nStates; s++) {
+ OutputStateId news = ofst->AddState();
+ assert(news == s);
+ }
+ ofst->SetStart(0);
+ for (OutputStateId this_state = 0; this_state < nStates; this_state++) {
+ std::vector<TempArc> &this_vec(output_arcs_[this_state]);
+
+ typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
+ end = this_vec.end();
+ for (; iter != end; ++iter) {
+ const TempArc &temp_arc(*iter);
+ std::vector<IntType> seq;
+ repository_.ConvertToVector(temp_arc.string, &seq);
+
+ if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
+ // Make a sequence of states going to a final state, with the strings
+ // as labels. Put the weight on the first arc.
+ OutputStateId cur_state = this_state;
+ for (size_t i = 0; i < seq.size(); i++) {
+ OutputStateId next_state = ofst->AddState();
+ Arc arc;
+ arc.nextstate = next_state;
+ arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
+ arc.ilabel = 0; // epsilon.
+ arc.olabel = seq[i];
+ ofst->AddArc(cur_state, arc);
+ cur_state = next_state;
+ }
+ ofst->SetFinal(cur_state,
+ (seq.size() == 0 ? temp_arc.weight : Weight::One()));
+ } else { // Really an arc.
+ OutputStateId cur_state = this_state;
+ // Have to be careful with this integer comparison (i+1 < seq.size())
+ // because unsigned. i < seq.size()-1 could fail for zero-length
+ // sequences.
+ for (size_t i = 0; i + 1 < seq.size(); i++) {
+ // for all but the last element of seq, create new state.
+ OutputStateId next_state = ofst->AddState();
+ Arc arc;
+ arc.nextstate = next_state;
+ arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
+ arc.ilabel = (i == 0 ? temp_arc.ilabel
+ : 0); // put ilabel on first element of seq.
+ arc.olabel = seq[i];
+ ofst->AddArc(cur_state, arc);
+ cur_state = next_state;
+ }
+ // Add the final arc in the sequence.
+ Arc arc;
+ arc.nextstate = temp_arc.nextstate;
+ arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
+ arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
+ arc.olabel = (seq.size() > 0 ? seq.back() : 0);
+ ofst->AddArc(cur_state, arc);
+ }
+ }
+ // Free up memory. Do this inside the loop as ofst is also allocating
+ // memory
+ if (destroy) {
+ std::vector<TempArc> temp;
+ temp.swap(this_vec);
+ }
+ }
+ if (destroy) {
+ std::vector<std::vector<TempArc> > temp;
+ temp.swap(output_arcs_);
+ repository_.Destroy();
+ }
+ }
+
+ // Initializer. After initializing the object you will typically
+ // call Determinize() and then call one of the Output functions.
+ // Note: ifst.Copy() will generally do a
+ // shallow copy. We do it like this for memory safety, rather than
+ // keeping a reference or pointer to ifst_.
+ LatticeDeterminizer(const Fst<Arc> &ifst, DeterminizeLatticeOptions opts)
+ : num_arcs_(0),
+ num_elems_(0),
+ ifst_(ifst.Copy()),
+ opts_(opts),
+ equal_(opts_.delta),
+ determinized_(false),
+ minimal_hash_(3, hasher_, equal_),
+ initial_hash_(3, hasher_, equal_) {
+ KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
+ // work correctly otherwise.
+ }
+
+ // frees all except output_arcs_, which contains the important info
+ // we need to output the FST.
+ void FreeMostMemory() {
+ if (ifst_) {
+ delete ifst_;
+ ifst_ = NULL;
+ }
+ for (typename MinimalSubsetHash::iterator iter = minimal_hash_.begin();
+ iter != minimal_hash_.end(); ++iter)
+ delete iter->first;
+ {
+ MinimalSubsetHash tmp;
+ tmp.swap(minimal_hash_);
+ }
+ for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
+ iter != initial_hash_.end(); ++iter)
+ delete iter->first;
+ {
+ InitialSubsetHash tmp;
+ tmp.swap(initial_hash_);
+ }
+ {
+ std::vector<std::vector<Element> *> output_states_tmp;
+ output_states_tmp.swap(output_states_);
+ }
+ {
+ std::vector<char> tmp;
+ tmp.swap(isymbol_or_final_);
+ }
+ {
+ std::vector<OutputStateId> tmp;
+ tmp.swap(queue_);
+ }
+ {
+ std::vector<std::pair<Label, Element> > tmp;
+ tmp.swap(all_elems_tmp_);
+ }
+ }
+
+ ~LatticeDeterminizer() {
+ FreeMostMemory(); // rest is deleted by destructors.
+ }
+ void RebuildRepository() { // rebuild the string repository,
+ // freeing stuff we don't need.. we call this when memory usage
+ // passes a supplied threshold. We need to accumulate all the
+ // strings we need the repository to "remember", then tell it
+ // to clean the repository.
+ std::vector<StringId> needed_strings;
+ for (size_t i = 0; i < output_arcs_.size(); i++)
+ for (size_t j = 0; j < output_arcs_[i].size(); j++)
+ needed_strings.push_back(output_arcs_[i][j].string);
+
+ // the following loop covers strings present in minimal_hash_
+ // which are also accessible via output_states_.
+ for (size_t i = 0; i < output_states_.size(); i++)
+ for (size_t j = 0; j < output_states_[i]->size(); j++)
+ needed_strings.push_back((*(output_states_[i]))[j].string);
+
+ // the following loop covers strings present in initial_hash_.
+ for (typename InitialSubsetHash::const_iterator iter =
+ initial_hash_.begin();
+ iter != initial_hash_.end(); ++iter) {
+ const std::vector<Element> &vec = *(iter->first);
+ Element elem = iter->second;
+ for (size_t i = 0; i < vec.size(); i++)
+ needed_strings.push_back(vec[i].string);
+ needed_strings.push_back(elem.string);
+ }
+
+ std::sort(needed_strings.begin(), needed_strings.end());
+ needed_strings.erase(
+ std::unique(needed_strings.begin(), needed_strings.end()),
+ needed_strings.end()); // uniq the strings.
+ repository_.Rebuild(needed_strings);
+ }
+
+ bool CheckMemoryUsage() {
+ int32 repo_size = repository_.MemSize(),
+ arcs_size = num_arcs_ * sizeof(TempArc),
+ elems_size = num_elems_ * sizeof(Element),
+ total_size = repo_size + arcs_size + elems_size;
+ if (opts_.max_mem > 0 &&
+ total_size > opts_.max_mem) { // We passed the memory threshold.
+ // This is usually due to the repository getting large, so we
+ // clean this out.
+ RebuildRepository();
+ int32 new_repo_size = repository_.MemSize(),
+ new_total_size = new_repo_size + arcs_size + elems_size;
+
+ KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository "
+ "shrank from "
+ << repo_size << " to " << new_repo_size
+ << " bytes (approximately)";
+
+ if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
+ // Rebuilding didn't help enough-- we need a margin to stop
+ // having to rebuild too often.
+ KALDI_WARN << "Failure in determinize-lattice: size exceeds maximum "
+ << opts_.max_mem << " bytes; (repo,arcs,elems) = ("
+ << repo_size << "," << arcs_size << "," << elems_size
+ << "), after rebuilding, repo size was " << new_repo_size;
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Returns true on success. Can fail for out-of-memory
+ // or max-states related reasons.
+ bool Determinize(bool *debug_ptr) {
+ assert(!determinized_);
+ // This determinizes the input fst but leaves it in the "special format"
+ // in "output_arcs_". Must be called after Initialize(). To get the
+ // output, call one of the Output routines.
+ try {
+ InitializeDeterminization(); // some start-up tasks.
+ while (!queue_.empty()) {
+ OutputStateId out_state = queue_.back();
+ queue_.pop_back();
+ ProcessState(out_state);
+ if (debug_ptr && *debug_ptr) Debug(); // will exit.
+ if (!CheckMemoryUsage()) return false;
+ }
+ return (determinized_ = true);
+ } catch (const std::bad_alloc &) {
+ int32 repo_size = repository_.MemSize(),
+ arcs_size = num_arcs_ * sizeof(TempArc),
+ elems_size = num_elems_ * sizeof(Element),
+ total_size = repo_size + arcs_size + elems_size;
+ KALDI_WARN
+ << "Memory allocation error doing lattice determinization; using "
+ << total_size << " bytes (max = " << opts_.max_mem
+ << " (repo,arcs,elems) = (" << repo_size << "," << arcs_size << ","
+ << elems_size << ")";
+ return (determinized_ = false);
+ } catch (const std::runtime_error &) {
+ KALDI_WARN << "Caught exception doing lattice determinization";
+ return (determinized_ = false);
+ }
+ }
+
+ private:
+ typedef typename Arc::Label Label;
+ typedef typename Arc::StateId
+ StateId; // use this when we don't know if it's input or output.
+ typedef typename Arc::StateId InputStateId; // state in the input FST.
+ typedef typename Arc::StateId OutputStateId; // same as above but distinguish
+ // states in output Fst.
+
+ typedef LatticeStringRepository<IntType> StringRepositoryType;
+ typedef const typename StringRepositoryType::Entry *StringId;
+
+ // Element of a subset [of original states]
+ struct Element {
+ StateId state; // use StateId as this is usually InputStateId but in one
+ // case OutputStateId.
+ StringId string;
+ Weight weight;
+ bool operator!=(const Element &other) const {
+ return (state != other.state || string != other.string ||
+ weight != other.weight);
+ }
+ // This operator is only intended to support sorting in EpsilonClosure()
+ bool operator<(const Element &other) const { return state < other.state; }
+ };
+
+ // Arcs in the format we temporarily create in this class (a representation,
+ // essentially of a Gallic Fst).
+ struct TempArc {
+ Label ilabel;
+ StringId string; // Look it up in the StringRepository, it's a sequence of
+ // Labels.
+ OutputStateId nextstate; // or kNoState for final weights.
+ Weight weight;
+ };
+
+ // Hashing function used in hash of subsets.
+ // A subset is a pointer to vector<Element>.
+ // The Elements are in sorted order on state id, and without repeated states.
+ // Because the order of Elements is fixed, we can use a hashing function that
+ // is order-dependent. However the weights are not included in the hashing
+ // function-- we hash subsets that differ only in weight to the same key. This
+ // is not optimal in terms of the O(N) performance but typically if we have a
+ // lot of determinized states that differ only in weight then the input
+ // probably was pathological in some way, or even non-determinizable.
+ // We don't quantize the weights, in order to avoid inexactness in simple
+ // cases.
+ // Instead we apply the delta when comparing subsets for equality, and allow a
+ // small difference.
+
+ class SubsetKey {
+ public:
+ size_t operator()(const std::vector<Element> *subset)
+ const { // hashes only the state and string.
+ size_t hash = 0, factor = 1;
+ for (typename std::vector<Element>::const_iterator iter = subset->begin();
+ iter != subset->end(); ++iter) {
+ hash *= factor;
+ hash += iter->state + reinterpret_cast<size_t>(iter->string);
+ factor *= 23531; // these numbers are primes.
+ }
+ return hash;
+ }
+ };
+
+ // This is the equality operator on subsets. It checks for exact match on
+ // state-id and string, and approximate match on weights.
+ class SubsetEqual {
+ public:
+ bool operator()(const std::vector<Element> *s1,
+ const std::vector<Element> *s2) const {
+ size_t sz = s1->size();
+ assert(sz >= 0);
+ if (sz != s2->size()) return false;
+ typename std::vector<Element>::const_iterator iter1 = s1->begin(),
+ iter1_end = s1->end(),
+ iter2 = s2->begin();
+ for (; iter1 < iter1_end; ++iter1, ++iter2) {
+ if (iter1->state != iter2->state || iter1->string != iter2->string ||
+ !ApproxEqual(iter1->weight, iter2->weight, delta_))
+ return false;
+ }
+ return true;
+ }
+ float delta_;
+ explicit SubsetEqual(float delta) : delta_(delta) {}
+ SubsetEqual() : delta_(kDelta) {}
+ };
+
+ // Operator that says whether two Elements have the same states.
+ // Used only for debug.
+ class SubsetEqualStates {
+ public:
+ bool operator()(const std::vector<Element> *s1,
+ const std::vector<Element> *s2) const {
+ size_t sz = s1->size();
+ assert(sz >= 0);
+ if (sz != s2->size()) return false;
+ typename std::vector<Element>::const_iterator iter1 = s1->begin(),
+ iter1_end = s1->end(),
+ iter2 = s2->begin();
+ for (; iter1 < iter1_end; ++iter1, ++iter2) {
+ if (iter1->state != iter2->state) return false;
+ }
+ return true;
+ }
+ };
+
+ // Define the hash type we use to map subsets (in minimal
+ // representation) to OutputStateId.
+ typedef std::unordered_map<const std::vector<Element> *, OutputStateId,
+ SubsetKey, SubsetEqual>
+ MinimalSubsetHash;
+
+ // Define the hash type we use to map subsets (in initial
+ // representation) to OutputStateId, together with an
+ // extra weight. [note: we interpret the Element.state in here
+ // as an OutputStateId even though it's declared as InputStateId;
+ // these types are the same anyway].
+ typedef std::unordered_map<const std::vector<Element> *, Element, SubsetKey,
+ SubsetEqual>
+ InitialSubsetHash;
+
+ // converts the representation of the subset from canonical (all states) to
+ // minimal (only states with output symbols on arcs leaving them, and final
+ // states). Output is not necessarily normalized, even if input_subset was.
+ void ConvertToMinimal(std::vector<Element> *subset) {
+ assert(!subset->empty());
+ typename std::vector<Element>::iterator cur_in = subset->begin(),
+ cur_out = subset->begin(),
+ end = subset->end();
+ while (cur_in != end) {
+ if (IsIsymbolOrFinal(cur_in->state)) { // keep it...
+ *cur_out = *cur_in;
+ cur_out++;
+ }
+ cur_in++;
+ }
+ subset->resize(cur_out - subset->begin());
+ }
+
+ // Takes a minimal, normalized subset, and converts it to an OutputStateId.
+ // Involves a hash lookup, and possibly adding a new OutputStateId.
+ // If it creates a new OutputStateId, it adds it to the queue.
+ OutputStateId MinimalToStateId(const std::vector<Element> &subset) {
+ typename MinimalSubsetHash::const_iterator iter =
+ minimal_hash_.find(&subset);
+ if (iter != minimal_hash_.end()) // Found a matching subset.
+ return iter->second;
+ OutputStateId ans = static_cast<OutputStateId>(output_arcs_.size());
+ std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
+ output_states_.push_back(subset_ptr);
+ num_elems_ += subset_ptr->size();
+ output_arcs_.push_back(std::vector<TempArc>());
+ minimal_hash_[subset_ptr] = ans;
+ queue_.push_back(ans);
+ return ans;
+ }
+
+ // Given a normalized initial subset of elements (i.e. before epsilon
+ // closure), compute the corresponding output-state.
+ OutputStateId InitialToStateId(const std::vector<Element> &subset_in,
+ Weight *remaining_weight,
+ StringId *common_prefix) {
+ typename InitialSubsetHash::const_iterator iter =
+ initial_hash_.find(&subset_in);
+ if (iter != initial_hash_.end()) { // Found a matching subset.
+ const Element &elem = iter->second;
+ *remaining_weight = elem.weight;
+ *common_prefix = elem.string;
+ if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
+ return elem.state;
+ }
+ // else no matching subset-- have to work it out.
+ std::vector<Element> subset(subset_in);
+ // Follow through epsilons. Will add no duplicate states. note: after
+ // EpsilonClosure, it is the same as "canonical" subset, except not
+ // normalized (actually we never compute the normalized canonical subset,
+ // only the normalized minimal one).
+ EpsilonClosure(&subset); // follow epsilons.
+ ConvertToMinimal(&subset); // remove all but emitting and final states.
+
+ Element elem; // will be used to store remaining weight and string, and
+ // OutputStateId, in initial_hash_;
+ NormalizeSubset(&subset, &elem.weight,
+ &elem.string); // normalize subset; put
+ // common string and weight in "elem". The subset is now a minimal,
+ // normalized subset.
+
+ OutputStateId ans = MinimalToStateId(subset);
+ *remaining_weight = elem.weight;
+ *common_prefix = elem.string;
+ if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
+
+ // Before returning "ans", add the initial subset to the hash,
+ // so that we can bypass the epsilon-closure etc., next time
+ // we process the same initial subset.
+ std::vector<Element> *initial_subset_ptr =
+ new std::vector<Element>(subset_in);
+ elem.state = ans;
+ initial_hash_[initial_subset_ptr] = elem;
+ num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
+ return ans;
+ }
+
+ // returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
+ // to the ordering we defined on strings for the CompactLatticeWeightTpl.
+ // see function
+ // inline int Compare (const CompactLatticeWeightTpl<Weight, IntType> &w1,
+ // const CompactLatticeWeightTpl<Weight, IntType> &w2)
+ // in lattice-weight.h.
+ // this is the same as that, but optimized for our data structures.
+ inline int Compare(const Weight &a_w, StringId a_str, const Weight &b_w,
+ StringId b_str) const {
+ int weight_comp = fst::Compare(a_w, b_w);
+ if (weight_comp != 0) return weight_comp;
+ // now comparing strings.
+ if (a_str == b_str) return 0;
+ std::vector<IntType> a_vec, b_vec;
+ repository_.ConvertToVector(a_str, &a_vec);
+ repository_.ConvertToVector(b_str, &b_vec);
+ // First compare their lengths.
+ int a_len = a_vec.size(), b_len = b_vec.size();
+ // use opposite order on the string lengths (c.f. Compare in
+ // lattice-weight.h)
+ if (a_len > b_len)
+ return -1;
+ else if (a_len < b_len)
+ return 1;
+ for (int i = 0; i < a_len; i++) {
+ if (a_vec[i] < b_vec[i])
+ return -1;
+ else if (a_vec[i] > b_vec[i])
+ return 1;
+ }
+ assert(
+ 0); // because we checked if a_str == b_str above, shouldn't reach here
+ return 0;
+ }
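+
+ // Illustrative example (hypothetical values): with equal weights, the
+ // shorter string ranks higher here, so comparing a = [5] against b = [5, 9]
+ // returns +1 for a; only when the lengths also match are the labels
+ // compared element by element.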
+
+ // This function computes epsilon closure of subset of states by following
+ // epsilon links. Called by InitialToStateId and Initialize. Has no side
+ // effects except on the string repository. The "output_subset" is not
+ // necessarily normalized (in the sense of there being no common substring),
+ // unless input_subset was.
+ void EpsilonClosure(std::vector<Element> *subset) {
+ // at input, subset must have only one example of each StateId. [will still
+ // be so at output]. This function follows input-epsilons, and augments the
+ // subset accordingly.
+
+ std::deque<Element> queue;
+ std::unordered_map<InputStateId, Element> cur_subset;
+ typedef
+ typename std::unordered_map<InputStateId, Element>::iterator MapIter;
+ typedef typename std::vector<Element>::const_iterator VecIter;
+
+ for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
+ queue.push_back(*iter);
+ cur_subset[iter->state] = *iter;
+ }
+
+ // find whether input fst is known to be sorted on input label.
+ bool sorted =
+ ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
+ bool replaced_elems = false; // relates to an optimization, see below.
+ int counter =
+ 0; // stops infinite loops here for non-lattice-determinizable input;
+ // useful in testing.
+ while (queue.size() != 0) {
+ Element elem = queue.front();
+ queue.pop_front();
+
+ // The next if-statement is a kind of optimization. It's to prevent us
+ // unnecessarily repeating the processing of a state. "cur_subset" always
+ // contains only one Element with a particular state. The issue is that
+ // whenever we modify the Element corresponding to that state in
+ // "cur_subset", both the new (optimal) and old (less-optimal) Element
+ // will still be in "queue". The next if-statement stops us from wasting
+ // compute by processing the old Element.
+ if (replaced_elems && cur_subset[elem.state] != elem) continue;
+ if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
+ KALDI_ERR << "Lattice determinization aborted since looped more than "
+ << opts_.max_loop << " times during epsilon closure";
+ }
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
+ aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ if (sorted && arc.ilabel != 0)
+ break; // Break from the loop: due to sorting there will be no
+ // more transitions with epsilons as input labels.
+ if (arc.ilabel == 0 &&
+ arc.weight != Weight::Zero()) { // Epsilon transition.
+ Element next_elem;
+ next_elem.state = arc.nextstate;
+ next_elem.weight = Times(elem.weight, arc.weight);
+ // now must append strings
+ if (arc.olabel == 0)
+ next_elem.string = elem.string;
+ else
+ next_elem.string = repository_.Successor(elem.string, arc.olabel);
+
+ MapIter iter = cur_subset.find(next_elem.state);
+ if (iter == cur_subset.end()) {
+ // was no such StateId: insert and add to queue.
+ cur_subset[next_elem.state] = next_elem;
+ queue.push_back(next_elem);
+ } else {
+ // was not inserted because one already there. In normal
+ // determinization we'd add the weights. Here, we find which one
+ // has the better weight, and keep its corresponding string.
+ int comp = Compare(next_elem.weight, next_elem.string,
+ iter->second.weight, iter->second.string);
+ if (comp ==
+ 1) { // next_elem is better, so use its (weight, string)
+ iter->second.string = next_elem.string;
+ iter->second.weight = next_elem.weight;
+ queue.push_back(next_elem);
+ replaced_elems = true;
+ }
+ // else it is the same or worse, so use original one.
+ }
+ }
+ }
+ }
+
+ { // copy cur_subset to subset.
+ subset->clear();
+ subset->reserve(cur_subset.size());
+ MapIter iter = cur_subset.begin(), end = cur_subset.end();
+ for (; iter != end; ++iter) subset->push_back(iter->second);
+ // sort by state ID, because the subset hash function is
+ // order-dependent(see SubsetKey)
+ std::sort(subset->begin(), subset->end());
+ }
+ }
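+
+ // Illustrative example (hypothetical values): starting from the subset
+ // { (state=2, w, eps) }, if state 2 has an input-epsilon arc to state 5
+ // with weight v and output label 9, the closure adds the element
+ // (state=5, Times(w, v), [9]) and the result is re-sorted by state id.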
+
+ // This function works out the final-weight of the determinized state.
+ // called by ProcessSubset.
+ // Has no side effects except on the variable repository_, and output_arcs_.
+
+ void ProcessFinal(OutputStateId output_state) {
+ const std::vector<Element> &minimal_subset =
+ *(output_states_[output_state]);
+ // processes final-weights for this subset.
+
+ // minimal_subset may be empty if the graph is not connected/trimmed, so
+ // we don't check that it's nonempty.
+ bool is_final = false;
+ StringId final_string = NULL; // = NULL to keep compiler happy.
+ Weight final_weight = Weight::Zero();
+ typename std::vector<Element>::const_iterator iter = minimal_subset.begin(),
+ end = minimal_subset.end();
+ for (; iter != end; ++iter) {
+ const Element &elem = *iter;
+ Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
+ StringId this_final_string = elem.string;
+ if (this_final_weight != Weight::Zero() &&
+ (!is_final || Compare(this_final_weight, this_final_string,
+ final_weight, final_string) == 1)) { // the new
+ // (weight, string) pair is better in the semiring than our current
+ // one.
+ is_final = true;
+ final_weight = this_final_weight;
+ final_string = this_final_string;
+ }
+ }
+ if (is_final) {
+ // store final weights in TempArc structure, just like a transition.
+ TempArc temp_arc;
+ temp_arc.ilabel = 0;
+ temp_arc.nextstate =
+ kNoStateId; // special marker meaning "final weight".
+ temp_arc.string = final_string;
+ temp_arc.weight = final_weight;
+ output_arcs_[output_state].push_back(temp_arc);
+ num_arcs_++;
+ }
+ }
+
+ // NormalizeSubset normalizes the subset "elems" by
+ // removing any common string prefix (putting it in common_str),
+ // and dividing by the total weight (putting it in tot_weight).
+ void NormalizeSubset(std::vector<Element> *elems, Weight *tot_weight,
+ StringId *common_str) {
+ if (elems->empty()) { // just set common_str, tot_weight
+ KALDI_WARN << "[empty subset]"; // TEMP
+ // to defaults and return...
+ *common_str = repository_.EmptyString();
+ *tot_weight = Weight::Zero();
+ return;
+ }
+ size_t size = elems->size();
+ std::vector<IntType> common_prefix;
+ repository_.ConvertToVector((*elems)[0].string, &common_prefix);
+ Weight weight = (*elems)[0].weight;
+ for (size_t i = 1; i < size; i++) {
+ weight = Plus(weight, (*elems)[i].weight);
+ repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
+ }
+ assert(weight != Weight::Zero()); // we made sure to ignore arcs with zero
+ // weights on them, so we shouldn't have zero here.
+ size_t prefix_len = common_prefix.size();
+ for (size_t i = 0; i < size; i++) {
+ (*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
+ (*elems)[i].string =
+ repository_.RemovePrefix((*elems)[i].string, prefix_len);
+ }
+ *common_str = repository_.ConvertFromVector(common_prefix);
+ *tot_weight = weight;
+ }
+
+ // Take a subset of Elements that is sorted on state, and
+ // merge any Elements that have the same state (taking the best
+ // (weight, string) pair in the semiring).
+ void MakeSubsetUnique(std::vector<Element> *subset) {
+ typedef typename std::vector<Element>::iterator IterType;
+
+ // This assert is designed to fail (usually) if the subset is not sorted on
+ // state.
+ assert(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
+
+ IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
+ size_t num_out = 0;
+ // Merge elements with same state-id
+ while (cur_in != end) { // while we have more elements to process.
+ // At this point, cur_out points to location of next place we want to put
+ // an element, cur_in points to location of next element we want to
+ // process.
+ if (cur_in != cur_out) *cur_out = *cur_in;
+ cur_in++;
+ while (cur_in != end && cur_in->state == cur_out->state) {
+ if (Compare(cur_in->weight, cur_in->string, cur_out->weight,
+ cur_out->string) == 1) {
+ // if *cur_in > *cur_out in semiring, then take *cur_in.
+ cur_out->string = cur_in->string;
+ cur_out->weight = cur_in->weight;
+ }
+ cur_in++;
+ }
+ cur_out++;
+ num_out++;
+ }
+ subset->resize(num_out);
+ }
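+
+ // Illustrative example (hypothetical values): given the state-sorted subset
+ // { (state=4, w1, s1), (state=4, w2, s2), (state=7, w3, s3) }, the two
+ // state-4 Elements are merged into one, keeping whichever (weight, string)
+ // pair Compare() ranks higher, so the result holds one Element per state.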
+
+ // ProcessTransition is called from "ProcessTransitions". Broken out for
+ // clarity. Processes a transition from state "state". The set of Elements
+ // represents a set of next-states with associated weights and strings, each
+ // one arising from an arc from some state in a determinized-state; the
+ // next-states are not necessarily unique (i.e. there may be >1 entry
+ // associated with each), and any such sets of Elements have to be merged
+ // within this routine (we take the [weight, string] pair that's better in the
+ // semiring).
+ void ProcessTransition(OutputStateId state, Label ilabel,
+ std::vector<Element> *subset) {
+ MakeSubsetUnique(subset); // remove duplicates with the same state.
+
+ StringId common_str;
+ Weight tot_weight;
+ NormalizeSubset(subset, &tot_weight, &common_str);
+
+ OutputStateId nextstate;
+ {
+ Weight next_tot_weight;
+ StringId next_common_str;
+ nextstate = InitialToStateId(*subset, &next_tot_weight, &next_common_str);
+ common_str = repository_.Concatenate(common_str, next_common_str);
+ tot_weight = Times(tot_weight, next_tot_weight);
+ }
+
+ // Now add an arc to the next state (would have been created if necessary by
+ // InitialToStateId).
+ TempArc temp_arc;
+ temp_arc.ilabel = ilabel;
+ temp_arc.nextstate = nextstate;
+ temp_arc.string = common_str;
+ temp_arc.weight = tot_weight;
+ output_arcs_[state].push_back(temp_arc); // record the arc.
+ num_arcs_++;
+ }
+
+ // "less than" operator for pair. Used in
+ // ProcessTransitions. Lexicographical order, which only compares the state
+ // when ordering the "Element" member of the pair.
+
+ class PairComparator {
+ public:
+ inline bool operator()(const std::pair<Label, Element> &p1,
+ const std::pair<Label, Element> &p2) {
+ if (p1.first < p2.first) {
+ return true;
+ } else if (p1.first > p2.first) {
+ return false;
+ } else {
+ return p1.second.state < p2.second.state;
+ }
+ }
+ };
+
+ // ProcessTransitions processes emitting transitions (transitions
+ // with ilabels) out of this subset of states.
+ // Does not consider final states. Breaks the emitting transitions up by
+ // ilabel, and creates a new transition in the determinized FST for each
+ // unique ilabel. Does this by creating a big vector of pairs <Label, Element>
+ // and then sorting them using a lexicographical ordering, and calling
+ // ProcessTransition for each range with the same ilabel. Side effects on
+ // repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
+
+ void ProcessTransitions(OutputStateId output_state) {
+ const std::vector<Element> &minimal_subset =
+ *(output_states_[output_state]);
+ // it's possible that minimal_subset could be empty if there are
+ // unreachable parts of the graph, so don't check that it's nonempty.
+ std::vector<std::pair<Label, Element> > &all_elems(
+ all_elems_tmp_); // use class member
+ // to avoid memory allocation/deallocation.
+ {
+ // Push back into "all_elems", elements corresponding to all
+ // non-epsilon-input transitions out of all states in "minimal_subset".
+ typename std::vector<Element>::const_iterator iter =
+ minimal_subset.begin(),
+ end = minimal_subset.end();
+ for (; iter != end; ++iter) {
+ const Element &elem = *iter;
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
+ aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ if (arc.ilabel != 0 &&
+ arc.weight != Weight::Zero()) { // Non-epsilon transition --
+ // ignore epsilons here.
+ std::pair<Label, Element> this_pr;
+ this_pr.first = arc.ilabel;
+ Element &next_elem(this_pr.second);
+ next_elem.state = arc.nextstate;
+ next_elem.weight = Times(elem.weight, arc.weight);
+ if (arc.olabel == 0) // output epsilon
+ next_elem.string = elem.string;
+ else
+ next_elem.string = repository_.Successor(elem.string, arc.olabel);
+ all_elems.push_back(this_pr);
+ }
+ }
+ }
+ }
+ PairComparator pc;
+ std::sort(all_elems.begin(), all_elems.end(), pc);
+ // now sorted first on input label, then on state.
+ typedef typename std::vector<std::pair<Label, Element> >::const_iterator
+ PairIter;
+ PairIter cur = all_elems.begin(), end = all_elems.end();
+ std::vector<Element> this_subset;
+ while (cur != end) {
+ // Process ranges that share the same input symbol.
+ Label ilabel = cur->first;
+ this_subset.clear();
+ while (cur != end && cur->first == ilabel) {
+ this_subset.push_back(cur->second);
+ cur++;
+ }
+ // We now have a subset for this ilabel.
+ assert(!this_subset.empty()); // temp.
+ ProcessTransition(output_state, ilabel, &this_subset);
+ }
+ all_elems.clear(); // as it's a class variable-- want it to stay
+ // empty.
+ }
+
+ // ProcessState does the processing of a determinized state, i.e. it creates
+ // transitions out of it and the final-probability if any.
+ void ProcessState(OutputStateId output_state) {
+ ProcessFinal(output_state);
+ ProcessTransitions(output_state);
+ }
+
+ void Debug() { // this function called if you send a signal
+ // SIGUSR1 to the process (and it's caught by the handler in
+ // fstdeterminizestar). It prints out some traceback
+ // info and exits.
+
+ KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
+ // free up memory from the hash as we need a little memory
+ {
+ MinimalSubsetHash hash_tmp;
+ hash_tmp.swap(minimal_hash_);
+ }
+
+ if (output_arcs_.size() <= 2) {
+ KALDI_ERR << "Nothing to trace back";
+ }
+ size_t max_state = output_arcs_.size() - 2; // Don't take the last
+ // one as we might be halfway into constructing it.
+
+ std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
+ for (size_t i = 0; i < max_state; i++) {
+ for (size_t j = 0; j < output_arcs_[i].size(); j++) {
+ OutputStateId nextstate = output_arcs_[i][j].nextstate;
+ // Always find an earlier-numbered predecessor; this
+ // is always possible because of the way the algorithm
+ // works.
+ if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
+ }
+ }
+ std::vector<std::pair<Label, StringId> > traceback;
+ // 'traceback' is a pair of (ilabel, olabel-seq).
+ OutputStateId cur_state = max_state; // A recently constructed state.
+
+ while (cur_state != 0 && cur_state != kNoStateId) {
+ OutputStateId last_state = predecessor[cur_state];
+ std::pair<Label, StringId> p;
+ size_t i;
+ for (i = 0; i < output_arcs_[last_state].size(); i++) {
+ if (output_arcs_[last_state][i].nextstate == cur_state) {
+ p.first = output_arcs_[last_state][i].ilabel;
+ p.second = output_arcs_[last_state][i].string;
+ traceback.push_back(p);
+ break;
+ }
+ }
+ KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
+ cur_state = last_state;
+ }
+ if (cur_state == kNoStateId)
+ KALDI_WARN << "Traceback did not reach start state "
+ << "(possibly debug-code error)";
+
+ std::stringstream ss;
+ ss << "Traceback follows in format "
+ << "ilabel (olabel olabel) ilabel (olabel) ... :";
+ for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
+ ss << ' ' << traceback[i].first << " ( ";
+ std::vector<IntType> seq;
+ repository_.ConvertToVector(traceback[i].second, &seq);
+ for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
+ ss << ')';
+ }
+ KALDI_ERR << ss.str();
+ }
+
+ bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
+ // of the input FST either is final or has an osymbol on an arc out of it.
+ // Uses the vector isymbol_or_final_ as a cache for this info.
+ assert(state >= 0);
+ if (isymbol_or_final_.size() <= state)
+ isymbol_or_final_.resize(state + 1, static_cast<char>(OSF_UNKNOWN));
+ if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
+ return false;
+ else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
+ return true;
+ // else work it out...
+ isymbol_or_final_[state] = static_cast<char>(OSF_NO);
+ if (ifst_->Final(state) != Weight::Zero())
+ isymbol_or_final_[state] = static_cast<char>(OSF_YES);
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, state); !aiter.Done();
+ aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
+ isymbol_or_final_[state] = static_cast(OSF_YES);
+ return true;
+ }
+ }
+ return IsIsymbolOrFinal(state); // will only recurse once.
+ }
+
+ void InitializeDeterminization() {
+ if (ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
+ // states in ifst_, it might be a bit more efficient
+ // to pre-size the hashes so we're not constantly rebuilding them.
+#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
+ StateId num_states =
+ down_cast<const ExpandedFst<Arc> *, const Fst<Arc> >(ifst_)
+ ->NumStates();
+ minimal_hash_.rehash(num_states / 2 + 3);
+ initial_hash_.rehash(num_states / 2 + 3);
+#endif
+ }
+ InputStateId start_id = ifst_->Start();
+ if (start_id != kNoStateId) {
+ /* Insert determinized-state corresponding to the start state into hash
+ and queue. Unlike all the other states, we don't "normalize" the
+ representation of this determinized-state before we put it into
+ minimal_hash_. This is actually what we want, as otherwise we'd have
+ problems dealing with any extra weight and string and might have to
+ create a "super-initial" state which would make the output
+ nondeterministic. Normalization is only needed to make the
+ determinized output more minimal anyway, it's not needed for
+ correctness. Note, we don't put anything in the initial_hash_. The
+ initial_hash_ is only a lookaside buffer anyway, so this isn't a
+ problem-- it will get populated later if it needs to be.
+ */
+ Element elem;
+ elem.state = start_id;
+ elem.weight = Weight::One();
+ elem.string = repository_.EmptyString(); // Id of empty sequence.
+ std::vector<Element> subset;
+ subset.push_back(elem);
+ EpsilonClosure(&subset); // follow through epsilon-inputs links
+ ConvertToMinimal(&subset); // remove all but final states and
+ // states with input-labels on arcs out of them.
+ std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
+ assert(output_arcs_.empty() && output_states_.empty());
+ // add the new state...
+ output_states_.push_back(subset_ptr);
+ output_arcs_.push_back(std::vector<TempArc>());
+ OutputStateId initial_state = 0;
+ minimal_hash_[subset_ptr] = initial_state;
+ queue_.push_back(initial_state);
+ }
+ }
+
+ KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizer);
+
+ std::vector<std::vector<Element> *>
+ output_states_; // maps from output state to
+ // minimal representation [normalized].
+ // View pointers as owned in
+ // minimal_hash_.
+ std::vector<std::vector<TempArc> >
+ output_arcs_; // essentially an FST in our format.
+
+ int num_arcs_; // keep track of memory usage: number of arcs in output_arcs_
+ int num_elems_; // keep track of memory usage: number of elems in
+ // output_states_
+
+ const Fst<Arc> *ifst_;
+ DeterminizeLatticeOptions opts_;
+ SubsetKey hasher_; // object that computes keys-- has no data members.
+ SubsetEqual
+ equal_; // object that compares subsets-- only data member is delta_.
+ bool determinized_; // set to true when user called Determinize(); used to
+ // make
+ // sure this object is used correctly.
+ MinimalSubsetHash
+ minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
+ // representation" (only include final and states and
+ // states with nonzero ilabel on arc out of them. Owns
+ // the pointers in its keys.
+ InitialSubsetHash initial_hash_; // hash from Subset to Element, which
+ // represents the OutputStateId together
+ // with an extra weight and string. Subset
+ // is "initial representation". The extra
+ // weight and string is needed because after
+ // we convert to minimal representation and
+ // normalize, there may be an extra weight
+ // and string. Owns the pointers
+ // in its keys.
+ std::vector<OutputStateId>
+ queue_; // Queue of output-states to process. Starts with
+ // state 0, and increases and then (hopefully) decreases in length during
+ // determinization. LIFO queue (queue discipline doesn't really matter).
+
+ std::vector<std::pair<Label, Element> >
+ all_elems_tmp_; // temporary vector used in ProcessTransitions.
+
+ enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
+
+ std::vector<char> isymbol_or_final_; // A kind of cache; it says whether
+ // each state is (emitting or final) where emitting means it has at least one
+ // non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
+
+ LatticeStringRepository<IntType>
+ repository_; // defines a compact and fast way of
+ // storing sequences of labels.
+};
+
+// normally Weight would be LatticeWeight (which has two floats),
+// or possibly TropicalWeightTpl, and IntType would be int32.
+template <class Weight, class IntType>
+bool DeterminizeLattice(const Fst<ArcTpl<Weight> > &ifst,
+ MutableFst<ArcTpl<Weight> > *ofst,
+ DeterminizeLatticeOptions opts, bool *debug_ptr) {
+ ofst->SetInputSymbols(ifst.InputSymbols());
+ ofst->SetOutputSymbols(ifst.OutputSymbols());
+ LatticeDeterminizer<Weight, IntType> det(ifst, opts);
+ if (!det.Determinize(debug_ptr)) return false;
+ det.Output(ofst);
+ return true;
+}
+
+// normally Weight would be LatticeWeight (which has two floats),
+// or possibly TropicalWeightTpl, and IntType would be int32.
+template <class Weight, class IntType>
+bool DeterminizeLattice(
+ const Fst<ArcTpl<Weight> > &ifst,
+ MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
+ DeterminizeLatticeOptions opts, bool *debug_ptr) {
+ ofst->SetInputSymbols(ifst.InputSymbols());
+ ofst->SetOutputSymbols(ifst.OutputSymbols());
+ LatticeDeterminizer<Weight, IntType> det(ifst, opts);
+ if (!det.Determinize(debug_ptr)) return false;
+ det.Output(ofst);
+ return true;
+}
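+
+// Illustrative usage sketch (hypothetical caller, comments only; the Lattice
+// and CompactLattice typedefs come from lat/kaldi-lattice.h):
+//
+//   fst::DeterminizeLatticeOptions opts;
+//   kaldi::Lattice lat = ...;       // input, an ArcTpl<LatticeWeight> FST
+//   kaldi::CompactLattice clat;     // output in compact form
+//   if (!fst::DeterminizeLattice<kaldi::LatticeWeight, kaldi::int32>(
+//           lat, &clat, opts, NULL))
+//     KALDI_WARN << "Lattice determinization failed (e.g. out of memory).";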
+
+} // namespace fst
+
+#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice.h b/speechx/speechx/kaldi/fstext/determinize-lattice.h
new file mode 100644
index 000000000..4a4251197
--- /dev/null
+++ b/speechx/speechx/kaldi/fstext/determinize-lattice.h
@@ -0,0 +1,144 @@
+// fstext/determinize-lattice.h
+
+// Copyright 2009-2011 Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
+#define KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
+#include <fst/fstlib.h>
+#include <fst/fst-decl.h>
+#include <algorithm>
+#include <map>
+#include <set>
+#include <vector>
+#include "fstext/lattice-weight.h"
+
+namespace fst {
+
+/// \addtogroup fst_extensions
+/// @{
+
+// For example of usage, see test-determinize-lattice.cc
+
+/*
+ DeterminizeLattice implements a special form of determinization
+ with epsilon removal, optimized for a phase of lattice generation.
+ Its input is an FST with weight-type BaseWeightType (usually a pair of
+ floats, with a lexicographical type of order, such as
+ LatticeWeightTpl