diff --git a/README.md b/README.md index a90498293..5093dbd67 100644 --- a/README.md +++ b/README.md @@ -280,10 +280,14 @@ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav For more information about server command lines, please see: [speech server demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/speech_server) + + ## Model List PaddleSpeech supports a series of most popular models. They are summarized in [released models](./docs/source/released_model.md) and attached with available pretrained models. + + **Speech-to-Text** contains *Acoustic Model*, *Language Model*, and *Speech Translation*, with the following details: @@ -357,6 +361,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+ 
+ 
 **Text-to-Speech** in PaddleSpeech mainly contains three modules: *Text Frontend*, *Acoustic Model* and *Vocoder*. Acoustic Model and Vocoder models are listed as follows:
@@ -457,10 +463,10 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
-      GE2E + Tactron2
+      GE2E + Tacotron2
       AISHELL-3
-      ge2e-tactron2-aishell3
+      ge2e-tacotron2-aishell3
@@ -473,6 +479,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+ + **Audio Classification** @@ -496,6 +504,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+ + **Speaker Verification** @@ -519,6 +529,8 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+ + **Punctuation Restoration** @@ -559,10 +571,18 @@ Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](ht - [Advanced Usage](./docs/source/tts/advanced_usage.md) - [Chinese Rule Based Text Frontend](./docs/source/tts/zh_text_frontend.md) - [Test Audio Samples](https://paddlespeech.readthedocs.io/en/latest/tts/demo.html) + - Speaker Verification + - [Audio Searching](./demos/audio_searching/README.md) + - [Speaker Verification](./demos/speaker_verification/README.md) - [Audio Classification](./demos/audio_tagging/README.md) - - [Speaker Verification](./demos/speaker_verification/README.md) - [Speech Translation](./demos/speech_translation/README.md) + - [Speech Server](./demos/speech_server/README.md) - [Released Models](./docs/source/released_model.md) + - [Speech-to-Text](#SpeechToText) + - [Text-to-Speech](#TextToSpeech) + - [Audio Classification](#AudioClassification) + - [Speaker Verification](#SpeakerVerification) + - [Punctuation Restoration](#PunctuationRestoration) - [Community](#Community) - [Welcome to contribute](#contribution) - [License](#License) diff --git a/README_cn.md b/README_cn.md index ab4ce6e6b..5dab7fa0c 100644 --- a/README_cn.md +++ b/README_cn.md @@ -273,6 +273,8 @@ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ## 模型列表 PaddleSpeech 支持很多主流的模型,并提供了预训练模型,详情请见[模型列表](./docs/source/released_model.md)。 + + PaddleSpeech 的 **语音转文本** 包含语音识别声学模型、语音识别语言模型和语音翻译, 详情如下:
@@ -347,6 +349,7 @@ PaddleSpeech 的 **语音转文本** 包含语音识别声学模型、语音识
+ 
 PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声学模型和声码器。声学模型和声码器模型如下:
@@ -447,10 +450,10 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
-      GE2E + Tactron2
+      GE2E + Tacotron2
       AISHELL-3
-      ge2e-tactron2-aishell3
+      ge2e-tacotron2-aishell3
@@ -488,6 +491,8 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+ + **声纹识别** @@ -511,6 +516,8 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+ + **标点恢复** @@ -556,13 +563,18 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 - [进阶用法](./docs/source/tts/advanced_usage.md) - [中文文本前端](./docs/source/tts/zh_text_frontend.md) - [测试语音样本](https://paddlespeech.readthedocs.io/en/latest/tts/demo.html) + - 声纹识别 + - [声纹识别](./demos/speaker_verification/README_cn.md) + - [音频检索](./demos/audio_searching/README_cn.md) - [声音分类](./demos/audio_tagging/README_cn.md) - - [声纹识别](./demos/speaker_verification/README_cn.md) - [语音翻译](./demos/speech_translation/README_cn.md) + - [服务化部署](./demos/speech_server/README_cn.md) - [模型列表](#模型列表) - [语音识别](#语音识别模型) - [语音合成](#语音合成模型) - [声音分类](#声音分类模型) + - [声纹识别](#声纹识别模型) + - [标点恢复](#标点恢复模型) - [技术交流群](#技术交流群) - [欢迎贡献](#欢迎贡献) - [License](#License) diff --git a/dataset/rir_noise/rir_noise.py b/dataset/rir_noise/rir_noise.py index e7b122890..009175e5b 100644 --- a/dataset/rir_noise/rir_noise.py +++ b/dataset/rir_noise/rir_noise.py @@ -34,14 +34,14 @@ from utils.utility import unzip DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') -URL_ROOT = 'http://www.openslr.org/resources/28' +URL_ROOT = '--no-check-certificate http://www.openslr.org/resources/28' DATA_URL = URL_ROOT + '/rirs_noises.zip' MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb' parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--target_dir", - default=DATA_HOME + "/Aishell", + default=DATA_HOME + "/rirs_noise", type=str, help="Directory to save the dataset. (default: %(default)s)") parser.add_argument( @@ -81,6 +81,10 @@ def create_manifest(data_dir, manifest_path_prefix): }, ensure_ascii=False)) manifest_path = manifest_path_prefix + '.' + dtype + + if not os.path.exists(os.path.dirname(manifest_path)): + os.makedirs(os.path.dirname(manifest_path)) + with codecs.open(manifest_path, 'w', 'utf-8') as fout: for line in json_lines: fout.write(line + '\n') diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py index 905862008..95827f708 100644 --- a/dataset/voxceleb/voxceleb1.py +++ b/dataset/voxceleb/voxceleb1.py @@ -149,7 +149,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path, # we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory if not os.path.exists(os.path.join(target_dir, "wav")): # download all dataset part - print("start to download the vox1 dev zip package") + print(f"start to download the vox1 zip package to {target_dir}") for zip_part in data_list.keys(): download_url = " --no-check-certificate " + base_url + "/" + zip_part download( diff --git a/dataset/voxceleb/voxceleb2.py b/dataset/voxceleb/voxceleb2.py index 22a2e2ffe..fe9e8b9c8 100644 --- a/dataset/voxceleb/voxceleb2.py +++ b/dataset/voxceleb/voxceleb2.py @@ -22,10 +22,12 @@ import codecs import glob import json import os +import subprocess from pathlib import Path import soundfile +from utils.utility import check_md5sum from utils.utility import download from utils.utility import unzip @@ -35,12 +37,22 @@ DATA_HOME = os.path.expanduser('.') BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/" # dev data -DEV_DATA_URL = BASE_URL + '/vox2_aac.zip' -DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402" +DEV_LIST = { + "vox2_dev_aac_partaa": "da070494c573e5c0564b1d11c3b20577", + "vox2_dev_aac_partab": "17fe6dab2b32b48abaf1676429cdd06f", + "vox2_dev_aac_partac": "1de58e086c5edf63625af1cb6d831528", + "vox2_dev_aac_partad": "5a043eb03e15c5a918ee6a52aad477f9", + "vox2_dev_aac_partae": "cea401b624983e2d0b2a87fb5d59aa60", + "vox2_dev_aac_partaf": 
"fc886d9ba90ab88e7880ee98effd6ae9", + "vox2_dev_aac_partag": "d160ecc3f6ee3eed54d55349531cb42e", + "vox2_dev_aac_partah": "6b84a81b9af72a9d9eecbb3b1f602e65", +} + +DEV_TARGET_DATA = "vox2_dev_aac_parta* vox2_dev_aac.zip bbc063c46078a602ca71605645c2a402" # test data -TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip' -TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312" +TEST_LIST = {"vox2_test_aac.zip": "0d2b3ea430a821c33263b5ea37ede312"} +TEST_TARGET_DATA = "vox2_test_aac.zip vox2_test_aac.zip 0d2b3ea430a821c33263b5ea37ede312" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -68,6 +80,14 @@ args = parser.parse_args() def create_manifest(data_dir, manifest_path_prefix): + """Generate the voxceleb2 dataset manifest file. + We will create the ${manifest_path_prefix}.vox2 as the final manifest file + The dev and test wav info will be put in one manifest file. + + Args: + data_dir (str): voxceleb2 wav directory, which include dev and test subdataset + manifest_path_prefix (str): manifest file prefix + """ print("Creating manifest %s ..." % manifest_path_prefix) json_lines = [] data_path = os.path.join(data_dir, "**", "*.wav") @@ -119,7 +139,19 @@ def create_manifest(data_dir, manifest_path_prefix): print(f"{total_sec / total_num} sec/utt", file=f) -def download_dataset(url, md5sum, target_dir, dataset): +def download_dataset(base_url, data_list, target_data, target_dir, dataset): + """Download the voxceleb2 zip package + + Args: + base_url (str): the voxceleb2 dataset download baseline url + data_list (dict): the dataset part zip package and the md5 value + target_data (str): the final dataset zip info + target_dir (str): the dataset stored directory + dataset (str): the dataset name, dev or test + + Raises: + RuntimeError: the md5sum occurs error + """ if not os.path.exists(target_dir): os.makedirs(target_dir) @@ -129,9 +161,34 @@ def download_dataset(url, md5sum, target_dir, dataset): # but the test dataset will unzip to aac # so, wo create the ${target_dir}/test and unzip the m4a to test dir if not os.path.exists(os.path.join(target_dir, dataset)): - filepath = download(url, md5sum, target_dir) + print(f"start to download the vox2 zip package to {target_dir}") + for zip_part in data_list.keys(): + download_url = " --no-check-certificate " + base_url + "/" + zip_part + download( + url=download_url, + md5sum=data_list[zip_part], + target_dir=target_dir) + + # pack the all part to target zip file + all_target_part, target_name, target_md5sum = target_data.split() + target_name = os.path.join(target_dir, target_name) + if not os.path.exists(target_name): + pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part, + target_name) + subprocess.call(pack_part_cmd, shell=True) + + # check the target zip file md5sum + if not check_md5sum(target_name, target_md5sum): + raise RuntimeError("{} MD5 checkssum failed".format(target_name)) + else: + print("Check {} md5sum successfully".format(target_name)) + if dataset == "test": - unzip(filepath, os.path.join(target_dir, "test")) + # we need make the test directory + unzip(target_name, os.path.join(target_dir, "test")) + else: + # upzip dev zip pacakge and will create the dev directory + unzip(target_name, target_dir) def main(): @@ -142,14 +199,16 @@ def main(): print("download: {}".format(args.download)) if args.download: download_dataset( - url=DEV_DATA_URL, - md5sum=DEV_MD5SUM, + base_url=BASE_URL, + data_list=DEV_LIST, + target_data=DEV_TARGET_DATA, target_dir=args.target_dir, dataset="dev") download_dataset( - 
url=TEST_DATA_URL, - md5sum=TEST_MD5SUM, + base_url=BASE_URL, + data_list=TEST_LIST, + target_data=TEST_TARGET_DATA, target_dir=args.target_dir, dataset="test") diff --git a/demos/speaker_verification/README.md b/demos/speaker_verification/README.md index 8739d402d..7d7180ae9 100644 --- a/demos/speaker_verification/README.md +++ b/demos/speaker_verification/README.md @@ -30,6 +30,11 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav paddlespeech vector --task spk --input vec.job echo -e "demo2 85236145389.wav \n demo3 85236145389.wav" | paddlespeech vector --task spk + + paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav" + + echo -e "demo4 85236145389.wav 85236145389.wav \n demo5 85236145389.wav 123456789.wav" > vec.job + paddlespeech vector --task score --input vec.job ``` Usage: @@ -38,6 +43,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav ``` Arguments: - `input`(required): Audio file to recognize. + - `task` (required): Specify `vector` task. Default `spk`。 - `model`: Model type of vector task. Default: `ecapatdnn_voxceleb12`. - `sample_rate`: Sample rate of the model. Default: `16000`. - `config`: Config of vector task. Use pretrained model when it is None. Default: `None`. @@ -47,45 +53,45 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav Output: ```bash - demo [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268 - -3.04878 1.611095 10.127234 -10.534177 -15.821609 - 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228 - -11.343508 2.3385992 -8.719341 14.213509 15.404744 - -0.39327756 6.338786 2.688887 8.7104025 17.469526 - -8.77959 7.0576906 4.648855 -1.3089896 -23.294737 - 8.013747 13.891729 -9.926753 5.655307 -5.9422326 - -22.842539 0.6293588 -18.46266 -10.811862 9.8192625 - 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942 - 1.7594414 -0.6485091 4.485623 2.0207152 7.264915 - -6.40137 23.63524 2.9711294 -22.708025 9.93719 - 20.354511 -10.324688 -0.700492 -8.783211 -5.27593 - 15.999649 3.3004563 12.747926 15.429879 4.7849145 - 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628 - 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124 - -9.224193 14.568347 -10.568833 4.982321 -4.342062 - 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362 - -6.680575 0.4757669 -5.035051 -6.7964664 16.865469 - -11.54324 7.681869 0.44475392 9.708182 -8.932846 - 0.4123232 -4.361452 1.3948607 9.511665 0.11667654 - 2.9079323 6.049952 9.275183 -18.078873 6.2983274 - -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815 - 4.010979 11.000591 -2.8873312 7.1352735 -16.79663 - 18.495346 -14.293832 7.89578 2.2714825 22.976387 - -4.875734 -3.0836344 -2.9999814 13.751918 6.448228 - -11.924197 2.171869 2.0423572 -6.173772 10.778437 - 25.77281 -4.9495463 14.57806 0.3044315 2.6132357 - -7.591999 -2.076944 9.025118 1.7834753 -3.1799617 - -4.9401326 23.465864 5.1685796 -9.018578 9.037825 - -4.4150195 6.859591 -12.274467 -0.88911164 5.186309 - -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652 - -12.397416 -12.719869 -1.395601 2.1150916 5.7381287 - -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127 - 8.731719 -20.778936 -11.495662 5.8033476 -4.752041 - 10.833007 -6.717991 4.504732 13.4244375 1.1306485 - 7.3435574 1.400918 14.704036 -9.501399 7.2315617 - -6.417456 1.3333273 11.872697 -0.30664724 8.8845 - 6.5569253 4.7948146 0.03662816 -8.704245 6.224871 - -3.2701402 -11.508579 ] + demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 + 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 + 2.623845 -0.8072968 1.9635103 
-7.3128724 0.01103897 + -9.723131 0.6619743 -6.976803 10.213478 7.494748 + 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 + -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 + 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 + -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 + 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 + -3.539873 3.814236 5.1420674 2.162061 4.096431 + -6.4162116 12.747448 1.9429878 -15.152943 6.417416 + 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 + 11.567354 3.69788 11.258265 7.442363 9.183411 + 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 + 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 + -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 + 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 + -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 + -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 + -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 + 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 + -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 + -0.31784213 9.493548 2.1144536 4.358092 -12.089823 + 8.451689 -7.925461 4.6242585 4.4289427 18.692003 + -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 + -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 + 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 + -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 + 0.66607 15.443222 4.740594 -3.4725387 11.592567 + -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 + -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 + -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 + -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 + 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 + 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 + 4.532936 2.7264361 10.145339 -6.521951 2.897153 + -3.3925855 5.079156 7.759716 4.677565 5.8457737 + 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 + -3.7760346 -11.118123 ] ``` - Python API @@ -97,56 +103,113 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav audio_emb = vector_executor( model='ecapatdnn_voxceleb12', sample_rate=16000, - config=None, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. ckpt_path=None, audio_file='./85236145389.wav', - force_yes=False, device=paddle.get_device()) print('Audio embedding Result: \n{}'.format(audio_emb)) + + test_emb = vector_executor( + model='ecapatdnn_voxceleb12', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. 
+ ckpt_path=None, + audio_file='./123456789.wav', + device=paddle.get_device()) + print('Test embedding Result: \n{}'.format(test_emb)) + + # score range [0, 1] + score = vector_executor.get_embeddings_score(audio_emb, test_emb) + print(f"Eembeddings Score: {score}") ``` - Output: + Output: + ```bash # Vector Result: - [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268 - -3.04878 1.611095 10.127234 -10.534177 -15.821609 - 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228 - -11.343508 2.3385992 -8.719341 14.213509 15.404744 - -0.39327756 6.338786 2.688887 8.7104025 17.469526 - -8.77959 7.0576906 4.648855 -1.3089896 -23.294737 - 8.013747 13.891729 -9.926753 5.655307 -5.9422326 - -22.842539 0.6293588 -18.46266 -10.811862 9.8192625 - 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942 - 1.7594414 -0.6485091 4.485623 2.0207152 7.264915 - -6.40137 23.63524 2.9711294 -22.708025 9.93719 - 20.354511 -10.324688 -0.700492 -8.783211 -5.27593 - 15.999649 3.3004563 12.747926 15.429879 4.7849145 - 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628 - 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124 - -9.224193 14.568347 -10.568833 4.982321 -4.342062 - 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362 - -6.680575 0.4757669 -5.035051 -6.7964664 16.865469 - -11.54324 7.681869 0.44475392 9.708182 -8.932846 - 0.4123232 -4.361452 1.3948607 9.511665 0.11667654 - 2.9079323 6.049952 9.275183 -18.078873 6.2983274 - -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815 - 4.010979 11.000591 -2.8873312 7.1352735 -16.79663 - 18.495346 -14.293832 7.89578 2.2714825 22.976387 - -4.875734 -3.0836344 -2.9999814 13.751918 6.448228 - -11.924197 2.171869 2.0423572 -6.173772 10.778437 - 25.77281 -4.9495463 14.57806 0.3044315 2.6132357 - -7.591999 -2.076944 9.025118 1.7834753 -3.1799617 - -4.9401326 23.465864 5.1685796 -9.018578 9.037825 - -4.4150195 6.859591 -12.274467 -0.88911164 5.186309 - -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652 - -12.397416 -12.719869 -1.395601 2.1150916 5.7381287 - -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127 - 8.731719 -20.778936 -11.495662 5.8033476 -4.752041 - 10.833007 -6.717991 4.504732 13.4244375 1.1306485 - 7.3435574 1.400918 14.704036 -9.501399 7.2315617 - -6.417456 1.3333273 11.872697 -0.30664724 8.8845 - 6.5569253 4.7948146 0.03662816 -8.704245 6.224871 - -3.2701402 -11.508579 ] + Audio embedding Result: + [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 + 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 + 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 + -9.723131 0.6619743 -6.976803 10.213478 7.494748 + 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 + -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 + 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 + -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 + 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 + -3.539873 3.814236 5.1420674 2.162061 4.096431 + -6.4162116 12.747448 1.9429878 -15.152943 6.417416 + 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 + 11.567354 3.69788 11.258265 7.442363 9.183411 + 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 + 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 + -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 + 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 + -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 + -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 + -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 + 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 + -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 + -0.31784213 9.493548 
2.1144536 4.358092 -12.089823 + 8.451689 -7.925461 4.6242585 4.4289427 18.692003 + -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 + -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 + 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 + -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 + 0.66607 15.443222 4.740594 -3.4725387 11.592567 + -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 + -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 + -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 + -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 + 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 + 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 + 4.532936 2.7264361 10.145339 -6.521951 2.897153 + -3.3925855 5.079156 7.759716 4.677565 5.8457737 + 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 + -3.7760346 -11.118123 ] + # get the test embedding + Test embedding Result: + [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125 + 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845 + -8.04332 4.344489 2.3200977 -14.306299 5.184692 + -11.55602 -3.8497238 0.6444722 1.2833948 2.6766639 + 0.5878921 0.7946299 1.7207596 2.5791872 14.998469 + -1.3385371 15.031221 -0.8006958 1.99287 -9.52007 + 2.435466 4.003221 -4.33817 -4.898601 -5.304714 + -18.033886 10.790787 -12.784645 -5.641755 2.9761686 + -10.566622 1.4839455 6.152458 -5.7195854 2.8603241 + 6.112133 8.489869 5.5958056 1.2836679 -1.2293907 + 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906 + 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029 + 10.521315 0.9604709 -3.3257897 7.144871 -13.592733 + -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123 + 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357 + -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422 + 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846 + 1.3420814 4.240421 -2.772944 -2.8451524 16.311104 + 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239 + -1.5708797 1.568961 1.1413603 3.5032008 -0.45251232 + -6.786333 16.89443 5.3366146 -8.789056 0.6355629 + 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468 + -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873 + 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476 + 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166 + -5.249467 -2.2671914 7.2658715 -13.298164 4.821147 + -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838 + -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094 + -3.4903061 2.2408795 5.5010734 -3.970756 11.99696 + -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706 + 11.635366 0.72157705 -9.245714 -3.91465 -4.449838 + -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864 + 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571 + 13.270688 3.448231 -7.0659585 4.5886116 -4.466099 + -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137 + -1.9962958 2.7119343 19.391657 0.01961198 14.607133 + -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604 + 12.0612335 5.9285784 3.3715196 1.492534 10.723728 + -0.95514804 -12.085431 ] + # get the score between enroll and test + Eembeddings Score: 0.4292638301849365 ``` ### 4.Pretrained Models diff --git a/demos/speaker_verification/README_cn.md b/demos/speaker_verification/README_cn.md index fe8949b3c..db382f298 100644 --- a/demos/speaker_verification/README_cn.md +++ b/demos/speaker_verification/README_cn.md @@ -29,6 +29,11 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav paddlespeech vector --task spk --input vec.job echo -e "demo2 85236145389.wav \n demo3 85236145389.wav" | paddlespeech vector --task spk + + paddlespeech vector --task score --input 
"./85236145389.wav ./123456789.wav" + + echo -e "demo4 85236145389.wav 85236145389.wav \n demo5 85236145389.wav 123456789.wav" > vec.job + paddlespeech vector --task score --input vec.job ``` 使用方法: @@ -37,6 +42,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav ``` 参数: - `input`(必须输入):用于识别的音频文件。 + - `task` (必须输入): 用于指定 `vector` 处理的具体任务,默认是 `spk`。 - `model`:声纹任务的模型,默认值:`ecapatdnn_voxceleb12`。 - `sample_rate`:音频采样率,默认值:`16000`。 - `config`:声纹任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 @@ -45,45 +51,45 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav 输出: ```bash - demo [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268 - -3.04878 1.611095 10.127234 -10.534177 -15.821609 - 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228 - -11.343508 2.3385992 -8.719341 14.213509 15.404744 - -0.39327756 6.338786 2.688887 8.7104025 17.469526 - -8.77959 7.0576906 4.648855 -1.3089896 -23.294737 - 8.013747 13.891729 -9.926753 5.655307 -5.9422326 - -22.842539 0.6293588 -18.46266 -10.811862 9.8192625 - 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942 - 1.7594414 -0.6485091 4.485623 2.0207152 7.264915 - -6.40137 23.63524 2.9711294 -22.708025 9.93719 - 20.354511 -10.324688 -0.700492 -8.783211 -5.27593 - 15.999649 3.3004563 12.747926 15.429879 4.7849145 - 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628 - 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124 - -9.224193 14.568347 -10.568833 4.982321 -4.342062 - 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362 - -6.680575 0.4757669 -5.035051 -6.7964664 16.865469 - -11.54324 7.681869 0.44475392 9.708182 -8.932846 - 0.4123232 -4.361452 1.3948607 9.511665 0.11667654 - 2.9079323 6.049952 9.275183 -18.078873 6.2983274 - -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815 - 4.010979 11.000591 -2.8873312 7.1352735 -16.79663 - 18.495346 -14.293832 7.89578 2.2714825 22.976387 - -4.875734 -3.0836344 -2.9999814 13.751918 6.448228 - -11.924197 2.171869 2.0423572 -6.173772 10.778437 - 25.77281 -4.9495463 14.57806 0.3044315 2.6132357 - -7.591999 -2.076944 9.025118 1.7834753 -3.1799617 - -4.9401326 23.465864 5.1685796 -9.018578 9.037825 - -4.4150195 6.859591 -12.274467 -0.88911164 5.186309 - -3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652 - -12.397416 -12.719869 -1.395601 2.1150916 5.7381287 - -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127 - 8.731719 -20.778936 -11.495662 5.8033476 -4.752041 - 10.833007 -6.717991 4.504732 13.4244375 1.1306485 - 7.3435574 1.400918 14.704036 -9.501399 7.2315617 - -6.417456 1.3333273 11.872697 -0.30664724 8.8845 - 6.5569253 4.7948146 0.03662816 -8.704245 6.224871 - -3.2701402 -11.508579 ] + demo [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 + 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 + 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 + -9.723131 0.6619743 -6.976803 10.213478 7.494748 + 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 + -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 + 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 + -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 + 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 + -3.539873 3.814236 5.1420674 2.162061 4.096431 + -6.4162116 12.747448 1.9429878 -15.152943 6.417416 + 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 + 11.567354 3.69788 11.258265 7.442363 9.183411 + 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 + 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 + -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 + 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 + -7.169265 
-1.0578263 -5.7216787 -5.1173844 16.137651 + -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 + -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 + 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 + -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 + -0.31784213 9.493548 2.1144536 4.358092 -12.089823 + 8.451689 -7.925461 4.6242585 4.4289427 18.692003 + -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 + -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 + 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 + -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 + 0.66607 15.443222 4.740594 -3.4725387 11.592567 + -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 + -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 + -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 + -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 + 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 + 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 + 4.532936 2.7264361 10.145339 -6.521951 2.897153 + -3.3925855 5.079156 7.759716 4.677565 5.8457737 + 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 + -3.7760346 -11.118123 ] ``` - Python API @@ -98,53 +104,109 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav config=None, # Set `config` and `ckpt_path` to None to use pretrained model. ckpt_path=None, audio_file='./85236145389.wav', - force_yes=False, device=paddle.get_device()) print('Audio embedding Result: \n{}'.format(audio_emb)) + + test_emb = vector_executor( + model='ecapatdnn_voxceleb12', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./123456789.wav', + device=paddle.get_device()) + print('Test embedding Result: \n{}'.format(test_emb)) + + # score range [0, 1] + score = vector_executor.get_embeddings_score(audio_emb, test_emb) + print(f"Eembeddings Score: {score}") ``` 输出: ```bash # Vector Result: - [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268 - -3.04878 1.611095 10.127234 -10.534177 -15.821609 - 1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228 - -11.343508 2.3385992 -8.719341 14.213509 15.404744 - -0.39327756 6.338786 2.688887 8.7104025 17.469526 - -8.77959 7.0576906 4.648855 -1.3089896 -23.294737 - 8.013747 13.891729 -9.926753 5.655307 -5.9422326 - -22.842539 0.6293588 -18.46266 -10.811862 9.8192625 - 3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942 - 1.7594414 -0.6485091 4.485623 2.0207152 7.264915 - -6.40137 23.63524 2.9711294 -22.708025 9.93719 - 20.354511 -10.324688 -0.700492 -8.783211 -5.27593 - 15.999649 3.3004563 12.747926 15.429879 4.7849145 - 5.6699696 -2.3826702 10.605882 3.9112158 3.1500628 - 15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124 - -9.224193 14.568347 -10.568833 4.982321 -4.342062 - 0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362 - -6.680575 0.4757669 -5.035051 -6.7964664 16.865469 - -11.54324 7.681869 0.44475392 9.708182 -8.932846 - 0.4123232 -4.361452 1.3948607 9.511665 0.11667654 - 2.9079323 6.049952 9.275183 -18.078873 6.2983274 - -0.7500531 -2.725033 -7.6027865 3.3404543 2.990815 - 4.010979 11.000591 -2.8873312 7.1352735 -16.79663 - 18.495346 -14.293832 7.89578 2.2714825 22.976387 - -4.875734 -3.0836344 -2.9999814 13.751918 6.448228 - -11.924197 2.171869 2.0423572 -6.173772 10.778437 - 25.77281 -4.9495463 14.57806 0.3044315 2.6132357 - -7.591999 -2.076944 9.025118 1.7834753 -3.1799617 - -4.9401326 23.465864 5.1685796 -9.018578 9.037825 - -4.4150195 6.859591 -12.274467 -0.88911164 5.186309 - -3.9988663 -13.638606 -9.925445 -0.06329413 
-3.6709652 - -12.397416 -12.719869 -1.395601 2.1150916 5.7381287 - -4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127 - 8.731719 -20.778936 -11.495662 5.8033476 -4.752041 - 10.833007 -6.717991 4.504732 13.4244375 1.1306485 - 7.3435574 1.400918 14.704036 -9.501399 7.2315617 - -6.417456 1.3333273 11.872697 -0.30664724 8.8845 - 6.5569253 4.7948146 0.03662816 -8.704245 6.224871 - -3.2701402 -11.508579 ] + Audio embedding Result: + [ 1.4217498 5.626253 -5.342073 1.1773866 3.308055 + 1.756596 5.167894 10.80636 -3.8226728 -5.6141334 + 2.623845 -0.8072968 1.9635103 -7.3128724 0.01103897 + -9.723131 0.6619743 -6.976803 10.213478 7.494748 + 2.9105635 3.8949256 3.7999806 7.1061673 16.905321 + -7.1493764 8.733103 3.4230042 -4.831653 -11.403367 + 11.232214 7.1274667 -4.2828417 2.452362 -5.130748 + -18.177666 -2.6116815 -11.000337 -6.7314315 1.6564683 + 0.7618269 1.1253023 -2.083836 4.725744 -8.782597 + -3.539873 3.814236 5.1420674 2.162061 4.096431 + -6.4162116 12.747448 1.9429878 -15.152943 6.417416 + 16.097002 -9.716668 -1.9920526 -3.3649497 -1.871939 + 11.567354 3.69788 11.258265 7.442363 9.183411 + 4.5281515 -1.2417862 4.3959084 6.6727695 5.8898783 + 7.627124 -0.66919386 -11.889693 -9.208865 -7.4274073 + -3.7776625 6.917234 -9.848748 -2.0944717 -5.135116 + 0.49563864 9.317534 -5.9141874 -1.8098574 -0.11738578 + -7.169265 -1.0578263 -5.7216787 -5.1173844 16.137651 + -4.473626 7.6624317 -0.55381083 9.631587 -6.4704556 + -8.548508 4.3716145 -0.79702514 4.478997 -2.9758704 + 3.272176 2.8382776 5.134597 -9.190781 -0.5657382 + -4.8745747 2.3165567 -5.984303 -2.1798875 0.35541576 + -0.31784213 9.493548 2.1144536 4.358092 -12.089823 + 8.451689 -7.925461 4.6242585 4.4289427 18.692003 + -2.6204622 -5.149185 -0.35821092 8.488551 4.981496 + -9.32683 -2.2544234 6.6417594 1.2119585 10.977129 + 16.555033 3.3238444 9.551863 -1.6676947 -0.79539716 + -8.605674 -0.47356385 2.6741948 -5.359179 -2.6673796 + 0.66607 15.443222 4.740594 -3.4725387 11.592567 + -2.054497 1.7361217 -8.265324 -9.30447 5.4068313 + -1.5180256 -7.746615 -6.089606 0.07112726 -0.34904733 + -8.649895 -9.998958 -2.564841 -0.53999114 2.601808 + -0.31927416 -1.8815292 -2.07215 -3.4105783 -8.2998085 + 1.483641 -15.365992 -8.288208 3.8847756 -3.4876456 + 7.3629923 0.4657332 3.132599 12.438889 -1.8337058 + 4.532936 2.7264361 10.145339 -6.521951 2.897153 + -3.3925855 5.079156 7.759716 4.677565 5.8457737 + 2.402413 7.7071047 3.9711342 -6.390043 6.1268735 + -3.7760346 -11.118123 ] + # get the test embedding + Test embedding Result: + [ -1.902964 2.0690894 -8.034194 3.5472693 0.18089125 + 6.9085927 1.4097427 -1.9487704 -10.021278 -0.20755845 + -8.04332 4.344489 2.3200977 -14.306299 5.184692 + -11.55602 -3.8497238 0.6444722 1.2833948 2.6766639 + 0.5878921 0.7946299 1.7207596 2.5791872 14.998469 + -1.3385371 15.031221 -0.8006958 1.99287 -9.52007 + 2.435466 4.003221 -4.33817 -4.898601 -5.304714 + -18.033886 10.790787 -12.784645 -5.641755 2.9761686 + -10.566622 1.4839455 6.152458 -5.7195854 2.8603241 + 6.112133 8.489869 5.5958056 1.2836679 -1.2293907 + 0.89927405 7.0288725 -2.854029 -0.9782962 5.8255906 + 14.905906 -5.025907 0.7866458 -4.2444224 -16.354029 + 10.521315 0.9604709 -3.3257897 7.144871 -13.592733 + -8.568869 -1.7953678 0.26313916 10.916714 -6.9374123 + 1.857403 -6.2746415 2.8154466 -7.2338667 -2.293357 + -0.05452765 5.4287076 5.0849075 -6.690375 -1.6183422 + 3.654291 0.94352573 -9.200294 -5.4749465 -3.5235846 + 1.3420814 4.240421 -2.772944 -2.8451524 16.311104 + 4.2969875 -1.762936 -12.5758915 8.595198 -0.8835239 + -1.5708797 1.568961 1.1413603 
3.5032008 -0.45251232 + -6.786333 16.89443 5.3366146 -8.789056 0.6355629 + 3.2579517 -3.328322 7.5969577 0.66025066 -6.550468 + -9.148656 2.020372 -0.4615173 1.1965656 -3.8764873 + 11.6562195 -6.0750933 12.182899 3.2218833 0.81969476 + 5.570001 -3.8459578 -7.205299 7.9262037 -7.6611166 + -5.249467 -2.2671914 7.2658715 -13.298164 4.821147 + -2.7263982 11.691089 -3.8918593 -2.838112 -1.0336838 + -3.8034165 2.8536487 -5.60398 -1.1972581 1.3455094 + -3.4903061 2.2408795 5.5010734 -3.970756 11.99696 + -7.8858757 0.43160373 -5.5059714 4.3426995 16.322706 + 11.635366 0.72157705 -9.245714 -3.91465 -4.449838 + -1.5716927 7.713747 -2.2430465 -6.198303 -13.481864 + 2.8156567 -5.7812386 5.1456156 2.7289324 -14.505571 + 13.270688 3.448231 -7.0659585 4.5886116 -4.466099 + -0.296428 -11.463529 -2.6076477 14.110243 -6.9725137 + -1.9962958 2.7119343 19.391657 0.01961198 14.607133 + -1.6695905 -4.391516 1.3131028 -6.670972 -5.888604 + 12.0612335 5.9285784 3.3715196 1.492534 10.723728 + -0.95514804 -12.085431 ] + # get the score between enroll and test + Eembeddings Score: 0.4292638301849365 ``` ### 4.预训练模型 diff --git a/demos/speaker_verification/run.sh b/demos/speaker_verification/run.sh index 856886d33..6140f7f38 100644 --- a/demos/speaker_verification/run.sh +++ b/demos/speaker_verification/run.sh @@ -1,6 +1,9 @@ #!/bin/bash wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav +wget -c https://paddlespeech.bj.bcebos.com/vector/audio/123456789.wav -# asr -paddlespeech vector --task spk --input ./85236145389.wav \ No newline at end of file +# vector +paddlespeech vector --task spk --input ./85236145389.wav + +paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav" diff --git a/docs/source/released_model.md b/docs/source/released_model.md index 9a423e03e..1cbe39895 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -6,7 +6,7 @@ ### Speech Recognition Model Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link :-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----: -[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) +[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.078 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) [Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0565 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell 
Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0483 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) @@ -37,8 +37,8 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (stati Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)||| Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB| TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)||| -SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip)|12MB| -FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip)|157MB| +SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB| +FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB| FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)||| FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)||| FastSpeech2| 
LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts3)|[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)||| @@ -80,7 +80,7 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https Model Type | Dataset| Example Link | Pretrained Models | Static Models :-------------:| :------------:| :-----: | :-----: | :-----: -PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz) | - +PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz) | - ## Punctuation Restoration Models Model Type | Dataset| Example Link | Pretrained Models diff --git a/examples/aishell/asr0/README.md b/examples/aishell/asr0/README.md index bb45d8df0..4459b1382 100644 --- a/examples/aishell/asr0/README.md +++ b/examples/aishell/asr0/README.md @@ -151,21 +151,14 @@ avg.sh best exp/deepspeech2/checkpoints 1 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/deepspeech2.yaml exp/deepspeech2/checkpoints/avg_1 ``` ## Pretrained Model -You can get the pretrained transformer or conformer using the scripts below: -```bash -Deepspeech2 offline: -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz - -Deepspeech2 online: -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz +You can get the pretrained models from [this](../../../docs/source/released_model.md). -``` using the `tar` scripts to unpack the model and then you can use the script to test the model. For example: ``` -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz -tar xzvf ds2.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz +tar xzvf asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz source path.sh # If you have process the data and get the manifest file, you can skip the following 2 steps bash local/data.sh --stage -1 --stop_stage -1 @@ -173,12 +166,7 @@ bash local/data.sh --stage 2 --stop_stage 2 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/deepspeech2.yaml exp/deepspeech2/checkpoints/avg_1 ``` -The performance of the released models are shown below: - -| Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | -| :----------------------------: | :-------------: | :---------: | -----: | :------------------------------------------------- | :---- | :--- | :-------------- | -| Ds2 Online Aishell ASR0 Model | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 | - | 151 h | -| Ds2 Offline Aishell ASR0 Model | Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers | 0.064 | - | 151 h | +The performance of the released models are shown in [this](./RESULTS.md) ## Stage 4: Static graph model Export This stage is to transform dygraph to static graph. 
```bash @@ -214,8 +202,8 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then ``` you can train the model by yourself, or you can download the pretrained model by the script below: ```bash -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz -tar xzvf ds2.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz +tar xzvf asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz ``` You can download the audio demo: ```bash diff --git a/examples/aishell/asr0/RESULTS.md b/examples/aishell/asr0/RESULTS.md index 5841a8522..8af3d66d1 100644 --- a/examples/aishell/asr0/RESULTS.md +++ b/examples/aishell/asr0/RESULTS.md @@ -4,15 +4,16 @@ | Model | Number of Params | Release | Config | Test set | Valid Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 45.18M | 2.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.994938373565674 | 0.080 | +| DeepSpeech2 | 45.18M | r0.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.708217620849609| 0.078 | +| DeepSpeech2 | 45.18M | v2.2.0 | conf/deepspeech2_online.yaml + spec aug | test | 7.994938373565674 | 0.080 | ## Deepspeech2 Non-Streaming | Model | Number of Params | Release | Config | Test set | Valid Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 | -| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | -| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | -| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| DeepSpeech2 | 58.4M | v2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 | +| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | +| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | +| DeepSpeech2 | 58.4M | v2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | +| DeepSpeech2 | 58.4M | v1.8.5 | - | test | - | 0.080447 | diff --git a/examples/aishell/asr1/README.md b/examples/aishell/asr1/README.md index 5277a31eb..25b28ede8 100644 --- a/examples/aishell/asr1/README.md +++ b/examples/aishell/asr1/README.md @@ -143,25 +143,14 @@ avg.sh best exp/conformer/checkpoints 20 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20 ``` ## Pretrained Model -You can get the pretrained transformer or conformer using the scripts below: +You can get the pretrained transformer or conformer from [this](../../../docs/source/released_model.md) -```bash -# Conformer: -wget https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz - -# Chunk Conformer: -wget https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz - -# Transformer: -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz - -``` using the `tar` scripts to unpack the model and then you can use the script to test the model. 
For example: ``` -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz -tar xzvf transformer.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz +tar xzvf asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz source path.sh # If you have process the data and get the manifest file, you can skip the following 2 steps bash local/data.sh --stage -1 --stop_stage -1 @@ -206,7 +195,7 @@ In some situations, you want to use the trained model to do the inference for th ``` you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below: ```bash -wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz tar xzvf transformer.model.tar.gz ``` You can download the audio demo: diff --git a/examples/aishell3/vc0/README.md b/examples/aishell3/vc0/README.md index 664ec1ac3..925663ab1 100644 --- a/examples/aishell3/vc0/README.md +++ b/examples/aishell3/vc0/README.md @@ -118,7 +118,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_outpu ``` ## Pretrained Model -[tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip) +- [tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip) Model | Step | eval/loss | eval/l1_loss | eval/mse_loss | eval/bce_loss| eval/attn_loss diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md index 04b83a5ff..8ab0f9c8c 100644 --- a/examples/aishell3/vc1/README.md +++ b/examples/aishell3/vc1/README.md @@ -119,7 +119,7 @@ ref_audio CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir} ``` ## Pretrained Model -[fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip) +- [fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip) Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: diff --git a/examples/aishell3/voc1/README.md b/examples/aishell3/voc1/README.md index dad464092..eb30e7c40 100644 --- a/examples/aishell3/voc1/README.md +++ b/examples/aishell3/voc1/README.md @@ -137,7 +137,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -Pretrained models can be downloaded here [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip). 
+Pretrained models can be downloaded here: +- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip) Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss:| eval/spectral_convergence_loss :-------------:| :------------:| :-----: | :-----: | :--------: diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md index ebe2530be..c957c4a3a 100644 --- a/examples/aishell3/voc5/README.md +++ b/examples/aishell3/voc5/README.md @@ -136,7 +136,8 @@ optional arguments: 4. `--output-dir` is the directory to save the synthesized audio files. 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip). +The pretrained model can be downloaded here: +- [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip) Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml new file mode 100755 index 000000000..319e44976 --- /dev/null +++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml @@ -0,0 +1,62 @@ +########################################################### +# AMI DATA PREPARE SETTING # +########################################################### +split_type: 'full_corpus_asr' +skip_TNO: True +# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt' +mic_type: 'Mix-Headset' +vad_type: 'oracle' +max_subseg_dur: 3.0 +overlap: 1.5 +# Some more exp folders (for cleaner structure). +embedding_dir: emb #!ref /emb +meta_data_dir: metadata #!ref /metadata +ref_rttm_dir: ref_rttms #!ref /ref_rttms +sys_rttm_dir: sys_rttms #!ref /sys_rttms +der_dir: DER #!ref /DER + + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# currently, we only support fbank +sr: 16000 # sample rate +n_mels: 80 +window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 +hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 +#left_frames: 0 +#right_frames: 0 +#deltas: False + + +########################################################### +# MODEL SETTING # +########################################################### +# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml +# if we want use another model, please choose another configuration yaml file +seed: 1234 +emb_dim: 192 +batch_size: 16 +model: + input_size: 80 + channels: [1024, 1024, 1024, 1024, 3072] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + attention_channels: 128 + lin_neurons: 192 +# Will automatically download ECAPA-TDNN model (best). 
+ +########################################################### +# SPECTRAL CLUSTERING SETTING # +########################################################### +backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity +affinity: 'cos' # options: cos, nn +max_num_spkrs: 10 +oracle_n_spkrs: True + + +########################################################### +# DER EVALUATION SETTING # +########################################################### +ignore_overlap: True +forgiveness_collar: 0.25 diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py new file mode 100644 index 000000000..dc824d7ca --- /dev/null +++ b/examples/ami/sd0/local/compute_embdding.py @@ -0,0 +1,231 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import pickle +import sys + +import numpy as np +import paddle +from paddle.io import BatchSampler +from paddle.io import DataLoader +from tqdm.contrib import tqdm +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.cluster.diarization import EmbeddingMeta +from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.io.dataset_from_json import JSONDataset +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.seeding import seed_everything + +# Logger setup +logger = Log(__name__).getlog() + + +def prepare_subset_json(full_meta_data, rec_id, out_meta_file): + """Prepares metadata for a given recording ID. + + Arguments + --------- + full_meta_data : json + Full meta (json) containing all the recordings + rec_id : str + The recording ID for which meta (json) has to be prepared + out_meta_file : str + Path of the output meta (json) file. + """ + + subset = {} + for key in full_meta_data: + k = str(key) + if k.startswith(rec_id): + subset[key] = full_meta_data[key] + + with open(out_meta_file, mode="w") as json_f: + json.dump(subset, json_f, indent=2) + + +def create_dataloader(json_file, batch_size): + """Creates the datasets and their data processing pipelines. + This is used for multi-mic processing. 
+ """ + + # create datasets + dataset = JSONDataset( + json_file=json_file, + feat_type='melspectrogram', + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + + # create dataloader + batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True) + dataloader = DataLoader(dataset, + batch_sampler=batch_sampler, + collate_fn=lambda x: batch_feature_normalize( + x, mean_norm=True, std_norm=False), + return_list=True) + + return dataloader + + +def main(args, config): + # set the training device, cpu or gpu + paddle.set_device(args.device) + # set the random seed + seed_everything(config.seed) + + # stage1: build the dnn backbone model network + ecapa_tdnn = EcapaTdnn(**config.model) + + # stage2: build the speaker verification eval instance with backbone model + model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1) + + # stage3: load the pre-trained model + # we get the last model from the epoch and save_interval + args.load_checkpoint = os.path.abspath( + os.path.expanduser(args.load_checkpoint)) + + # load model checkpoint to sid model + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdparams')) + model.set_state_dict(state_dict) + logger.info(f'Checkpoint loaded from {args.load_checkpoint}') + + # set the model to eval mode + model.eval() + + # load meta data + meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", ) + with open(meta_file, "r") as f: + full_meta = json.load(f) + + # get all the recording IDs in this dataset. + all_keys = full_meta.keys() + A = [word.rstrip().split("_")[0] for word in all_keys] + all_rec_ids = list(set(A[1:])) + all_rec_ids.sort() + split = "AMI_" + args.dataset + i = 1 + + msg = "Extra embdding for " + args.dataset + " set" + logger.info(msg) + + if len(all_rec_ids) <= 0: + msg = "No recording IDs found! Please check if meta_data json file is properly generated." + logger.error(msg) + sys.exit() + + # extra different recordings embdding in a dataset. + for rec_id in tqdm(all_rec_ids): + # This tag will be displayed in the log. + tag = ("[" + str(args.dataset) + ": " + str(i) + "/" + + str(len(all_rec_ids)) + "]") + i = i + 1 + + # log message. + msg = "Embdding %s : %s " % (tag, rec_id) + logger.debug(msg) + + # embedding directory. + if not os.path.exists( + os.path.join(args.data_dir, config.embedding_dir, split)): + os.makedirs( + os.path.join(args.data_dir, config.embedding_dir, split)) + + # file to store embeddings. + emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl" + diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir, + split, emb_file_name) + + # prepare a metadata (json) for one recording. This is basically a subset of full_meta. + # lets keep this meta-info in embedding directory itself. + json_file_name = rec_id + "." + config.mic_type + ".json" + meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir, + split, json_file_name) + + # write subset (meta for one recording) json metadata. + prepare_subset_json(full_meta, rec_id, meta_per_rec_file) + + # prepare data loader. + diary_set_loader = create_dataloader(meta_per_rec_file, + config.batch_size) + + # extract embeddings (skip if already done). 
+ if not os.path.isfile(diary_stat_emb_file): + logger.debug("Extracting deep embeddings") + embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64) + segset = [] + + for batch_idx, batch in enumerate(tqdm(diary_set_loader)): + # extrac the audio embedding + ids, feats, lengths = batch['ids'], batch['feats'], batch[ + 'lengths'] + seg = [x for x in ids] + segset = segset + seg + emb = model.backbone(feats, lengths).squeeze( + -1).numpy() # (N, emb_size, 1) -> (N, emb_size) + embeddings = np.concatenate((embeddings, emb), axis=0) + + segset = np.array(segset, dtype="|O") + stat_obj = EmbeddingMeta( + segset=segset, + stats=embeddings, ) + logger.debug("Saving Embeddings...") + with open(diary_stat_emb_file, "wb") as output: + pickle.dump(stat_obj, output) + + else: + logger.debug("Skipping embedding extraction (as already present).") + + +# Begin experiment! +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + '--device', + default="gpu", + help="Select which device to perform diarization, defaults to gpu.") + parser.add_argument( + "--config", default=None, type=str, help="configuration file") + parser.add_argument( + "--data-dir", + default="../save/", + type=str, + help="processsed data directory") + parser.add_argument( + "--dataset", + choices=['dev', 'eval'], + default="dev", + type=str, + help="Select which dataset to extra embdding, defaults to dev") + parser.add_argument( + "--load-checkpoint", + type=str, + default='', + help="Directory to load model checkpoint to compute embeddings.") + args = parser.parse_args() + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + + main(args, config) diff --git a/examples/ami/sd0/local/data.sh b/examples/ami/sd0/local/data.sh deleted file mode 100755 index 478ec432d..000000000 --- a/examples/ami/sd0/local/data.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash - -stage=1 - -TARGET_DIR=${MAIN_ROOT}/dataset/ami -data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ -manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ - -save_folder=${MAIN_ROOT}/examples/ami/sd0/data -ref_rttm_dir=${save_folder}/ref_rttms -meta_data_dir=${save_folder}/metadata - -set=L - -. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -set -u -set -o pipefail - -mkdir -p ${save_folder} - -if [ ${stage} -le 0 ]; then - # Download AMI corpus, You need around 10GB of free space to get whole data - # The signals are too large to package in this way, - # so you need to use the chooser to indicate which ones you wish to download - echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data." - echo "Annotations: AMI manual annotations v1.6.2 " - echo "Signals: " - echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py" - echo "2) Select media streams: Just select Headset mix" - exit 0; -fi - -if [ ${stage} -le 1 ]; then - echo "AMI Data preparation" - - python local/ami_prepare.py --data_folder ${data_folder} \ - --manual_annot_folder ${manual_annot_folder} \ - --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \ - --meta_data_dir ${meta_data_dir} - - if [ $? -ne 0 ]; then - echo "Prepare AMI failed. Please check log message." - exit 1 - fi - -fi - -echo "AMI data preparation done." 
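
> Note (not part of the patch): the per-recording statistics written above can be read back before clustering. A minimal sketch, assuming a file path of the form produced by the script; attribute names follow the `EmbeddingMeta(segset=..., stats=...)` call shown above.

```python
import pickle
import numpy as np

# Needed so pickle can resolve the class when loading.
from paddlespeech.vector.cluster.diarization import EmbeddingMeta

# Illustrative path; real files live under <data_dir>/<embedding_dir>/AMI_<dataset>/.
emb_file = "save/emb/AMI_dev/ES2011a.Mix-Headset.emb_stat.pkl"

with open(emb_file, "rb") as f:
    meta = pickle.load(f)

# meta.segset holds the sub-segment ids, meta.stats the (N, emb_dim) embeddings.
print(len(meta.segset), np.asarray(meta.stats).shape)
```
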
-exit 0 diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py new file mode 100755 index 000000000..298228376 --- /dev/null +++ b/examples/ami/sd0/local/experiment.py @@ -0,0 +1,428 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob +import json +import os +import pickle +import shutil +import sys + +import numpy as np +from tqdm.contrib import tqdm +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.cluster import diarization as diar +from utils.DER import DER + +# Logger setup +logger = Log(__name__).getlog() + + +def diarize_dataset( + full_meta, + split_type, + n_lambdas, + pval, + save_dir, + config, + n_neighbors=10, ): + """This function diarizes all the recordings in a given dataset. It performs + computation of embedding and clusters them using spectral clustering (or other backends). + The output speaker boundary file is stored in the RTTM format. + """ + + # prepare `spkr_info` only once when Oracle num of speakers is selected. + # spkr_info is essential to obtain number of speakers from groundtruth. + if config.oracle_n_spkrs is True: + full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_" + split_type + ".rttm") + rttm = diar.read_rttm(full_ref_rttm_file) + + spkr_info = list( # noqa F841 + filter(lambda x: x.startswith("SPKR-INFO"), rttm)) + + # get all the recording IDs in this dataset. + all_keys = full_meta.keys() + A = [word.rstrip().split("_")[0] for word in all_keys] + all_rec_ids = list(set(A[1:])) + all_rec_ids.sort() + split = "AMI_" + split_type + i = 1 + + # adding tag for directory path. + type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est" + tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend) + + # make out rttm dir + out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type, + split, tag) + if not os.path.exists(out_rttm_dir): + os.makedirs(out_rttm_dir) + + # diarizing different recordings in a dataset. + for rec_id in tqdm(all_rec_ids): + # this tag will be displayed in the log. + tag = ("[" + str(split_type) + ": " + str(i) + "/" + + str(len(all_rec_ids)) + "]") + i = i + 1 + + # log message. + msg = "Diarizing %s : %s " % (tag, rec_id) + logger.debug(msg) + + # load embeddings. + emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl" + diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir, + split, emb_file_name) + if not os.path.isfile(diary_stat_emb_file): + msg = "Embdding file %s not found! Please check if embdding file is properly generated." % ( + diary_stat_emb_file) + logger.error(msg) + sys.exit() + with open(diary_stat_emb_file, "rb") as in_file: + diary_obj = pickle.load(in_file) + + out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm" + + # processing starts from here. + if config.oracle_n_spkrs is True: + # oracle num of speakers. 
+ num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info) + else: + if config.affinity == "nn": + # num of speakers tunned on dev set (only for nn affinity). + num_spkrs = n_lambdas + else: + # num of speakers will be estimated using max eigen gap for cos based affinity. + # so adding None here. Will use this None later-on. + num_spkrs = None + + if config.backend == "kmeans": + diar.do_kmeans_clustering( + diary_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, ) + + if config.backend == "SC": + # go for Spectral Clustering (SC). + diar.do_spec_clustering( + diary_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, + config.affinity, + n_neighbors, ) + + # can used for AHC later. Likewise one can add different backends here. + if config.backend == "AHC": + # call AHC + threshold = pval # pval for AHC is nothing but threshold. + diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold) + + # once all RTTM outputs are generated, concatenate individual RTTM files to obtain single RTTM file. + # this is not needed but just staying with the standards. + concate_rttm_file = out_rttm_dir + "/sys_output.rttm" + logger.debug("Concatenating individual RTTM files...") + with open(concate_rttm_file, "w") as cat_file: + for f in glob.glob(out_rttm_dir + "/*.rttm"): + if f == concate_rttm_file: + continue + with open(f, "r") as indi_rttm_file: + shutil.copyfileobj(indi_rttm_file, cat_file) + + msg = "The system generated RTTM file for %s set : %s" % ( + split_type, concate_rttm_file, ) + logger.debug(msg) + + return concate_rttm_file + + +def dev_pval_tuner(full_meta, save_dir, config): + """Tuning p_value for affinity matrix. + The p_value used so that only p% of the values in each row is retained. + """ + + DER_list = [] + prange = np.arange(0.002, 0.015, 0.001) + + n_lambdas = None # using it as flag later. + for p_v in prange: + # Process whole dataset for value of p_v. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm_file = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm_file, + sys_rttm_file, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + if config.oracle_n_spkrs is True and config.backend == "kmeans": + # no need of p_val search. Note p_val is needed for SC for both oracle and est num of speakers. + # p_val is needed in oracle_n_spkr=False when using kmeans backend. + break + + # Take p_val that gave minmum DER on Dev dataset. + tuned_p_val = prange[DER_list.index(min(DER_list))] + + return tuned_p_val + + +def dev_ahc_threshold_tuner(full_meta, save_dir, config): + """Tuning threshold for affinity matrix. This function is called when AHC is used as backend. + """ + + DER_list = [] + prange = np.arange(0.0, 1.0, 0.1) + + n_lambdas = None # using it as flag later. + + # Note: p_val is threshold in case of AHC. + for p_v in prange: + # Process whole dataset for value of p_v. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + if config.oracle_n_spkrs is True: + break # no need of threshold search. + + # Take p_val that gave minmum DER on Dev dataset. 
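
> Note (not part of the patch): the `pval` being tuned here controls how aggressively the affinity matrix is sparsified before spectral clustering, with only roughly the top p% of entries in each row retained. A small NumPy sketch of that pruning step, independent of the PaddleSpeech implementation; function and variable names are illustrative.

```python
import numpy as np

def prune_affinity_rows(affinity: np.ndarray, p_val: float) -> np.ndarray:
    """Keep only the largest ~p_val fraction of entries in each row, zero the rest."""
    n = affinity.shape[0]
    keep = max(1, int(round(p_val * n)))      # entries retained per row
    pruned = np.zeros_like(affinity)
    for i, row in enumerate(affinity):
        top_idx = np.argsort(row)[-keep:]     # indices of the largest entries
        pruned[i, top_idx] = row[top_idx]
    # symmetrize so the result is still a valid affinity matrix
    return 0.5 * (pruned + pruned.T)

# Example: cosine affinity from L2-normalized embeddings of shape (N, D).
emb = np.random.randn(8, 4)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
affinity = emb @ emb.T
sparse_affinity = prune_affinity_rows(affinity, p_val=0.005)
```
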
+ tuned_p_val = prange[DER_list.index(min(DER_list))] + + return tuned_p_val + + +def dev_nn_tuner(full_meta, split_type, save_dir, config): + """Tuning n_neighbors on dev set. Assuming oracle num of speakers. + This is used when nn based affinity is selected. + """ + + DER_list = [] + pval = None + + # Now assumming oracle num of speakers. + n_lambdas = 4 + + for nn in range(5, 15): + + # Process whole dataset for value of n_lambdas. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config, nn) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append([nn, DER_]) + + if config.oracle_n_spkrs is True and config.backend == "kmeans": + break + + DER_list.sort(key=lambda x: x[1]) + tunned_nn = DER_list[0] + + return tunned_nn[0] + + +def dev_tuner(full_meta, split_type, save_dir, config): + """Tuning n_components on dev set. Used for nn based affinity matrix. + Note: This is a very basic tunning for nn based affinity. + This is work in progress till we find a better way. + """ + + DER_list = [] + pval = None + for n_lambdas in range(1, config.max_num_spkrs + 1): + + # Process whole dataset for value of n_lambdas. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + # Take n_lambdas with minmum DER. + tuned_n_lambdas = DER_list.index(min(DER_list)) + 1 + + return tuned_n_lambdas + + +def main(args, config): + # AMI Dev Set: Tune hyperparams on dev set. + # Read the embdding file for dev set generated during embdding compute + dev_meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_dev." + config.mic_type + ".subsegs.json", ) + with open(dev_meta_file, "r") as f: + meta_dev = json.load(f) + + full_meta = meta_dev + + # Processing starts from here + # Following few lines selects option for different backend and affinity matrices. Finds best values for hyperameters using dev set. 
+ ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + best_nn = None + if config.affinity == "nn": + logger.info("Tuning for nn (Multiple iterations over AMI Dev set)") + best_nn = dev_nn_tuner(full_meta, args.data_dir, config) + + n_lambdas = None + best_pval = None + + if config.affinity == "cos" and (config.backend == "SC" or + config.backend == "kmeans"): + # oracle num_spkrs or not, doesn't matter for kmeans and SC backends + # cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs) + logger.info( + "Tuning for p-value for SC (Multiple iterations over AMI Dev set)") + best_pval = dev_pval_tuner(full_meta, args.data_dir, config) + + elif config.backend == "AHC": + logger.info("Tuning for threshold-value for AHC") + best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir, + config) + best_pval = best_threshold + else: + # NN for unknown num of speakers (can be used in future) + if config.oracle_n_spkrs is False: + # nn: Tune num of number of components (to be updated later) + logger.info( + "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)" + ) + # dev_tuner used for tuning num of components in NN. Can be used in future. + n_lambdas = dev_tuner(full_meta, args.data_dir, config) + + # load 'dev' and 'eval' metadata files. + full_meta_dev = full_meta # current full_meta is for 'dev' + eval_meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_eval." + config.mic_type + ".subsegs.json", ) + with open(eval_meta_file, "r") as f: + full_meta_eval = json.load(f) + + # tag to be appended to final output DER files. Writing DER for individual files. + type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est" + tag = ( + type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type) + + # perform final diarization on 'dev' and 'eval' with best hyperparams. + final_DERs = {} + out_der_dir = os.path.join(args.data_dir, config.der_dir) + if not os.path.exists(out_der_dir): + os.makedirs(out_der_dir) + + for split_type in ["dev", "eval"]: + if split_type == "dev": + full_meta = full_meta_dev + else: + full_meta = full_meta_eval + + # performing diarization. + msg = "Diarizing using best hyperparams: " + split_type + " set" + logger.info(msg) + out_boundaries = diarize_dataset( + full_meta, + split_type, + n_lambdas=n_lambdas, + pval=best_pval, + n_neighbors=best_nn, + save_dir=args.data_dir, + config=config) + + # computing DER. + msg = "Computing DERs for " + split_type + " set" + logger.info(msg) + ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir, + "fullref_ami_" + split_type + ".rttm") + sys_rttm = out_boundaries + [MS, FA, SER, DER_vals] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, + individual_file_scores=True, ) + + # writing DER values to a file. Append tag. 
+ der_file_name = split_type + "_DER_" + tag + out_der_file = os.path.join(out_der_dir, der_file_name) + msg = "Writing DER file to: " + out_der_file + logger.info(msg) + diar.write_ders_file(ref_rttm, DER_vals, out_der_file) + + msg = ("AMI " + split_type + " set DER = %s %%\n" % + (str(round(DER_vals[-1], 2)))) + logger.info(msg) + final_DERs[split_type] = round(DER_vals[-1], 2) + + # final print DERs + msg = ( + "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n" + % (str(final_DERs["dev"]), str(final_DERs["eval"]))) + logger.info(msg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + "--config", default=None, type=str, help="configuration file") + parser.add_argument( + "--data-dir", + default="../data/", + type=str, + help="processsed data directory") + args = parser.parse_args() + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + + main(args, config) diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh new file mode 100755 index 000000000..1dfd11b86 --- /dev/null +++ b/examples/ami/sd0/local/process.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +stage=0 +set=L + +. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; +set -o pipefail + +data_folder=$1 +manual_annot_folder=$2 +save_folder=$3 +pretrained_model_dir=$4 +conf_path=$5 +device=$6 + +ref_rttm_dir=${save_folder}/ref_rttms +meta_data_dir=${save_folder}/metadata + +if [ ${stage} -le 0 ]; then + echo "AMI Data preparation" + python local/ami_prepare.py --data_folder ${data_folder} \ + --manual_annot_folder ${manual_annot_folder} \ + --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \ + --meta_data_dir ${meta_data_dir} + + if [ $? -ne 0 ]; then + echo "Prepare AMI failed. Please check log message." + exit 1 + fi + echo "AMI data preparation done." +fi + +if [ ${stage} -le 1 ]; then + # extra embddings for dev and eval dataset + for name in dev eval; do + python local/compute_embdding.py --config ${conf_path} \ + --data-dir ${save_folder} \ + --device ${device} \ + --dataset ${name} \ + --load-checkpoint ${pretrained_model_dir} + done +fi + +if [ ${stage} -le 2 ]; then + # tune hyperparams on dev set + # perform final diarization on 'dev' and 'eval' with best hyperparams + python local/experiment.py --config ${conf_path} \ + --data-dir ${save_folder} +fi diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh index 91d4b706a..9035f5955 100644 --- a/examples/ami/sd0/run.sh +++ b/examples/ami/sd0/run.sh @@ -1,14 +1,46 @@ #!/bin/bash -. path.sh || exit 1; +. ./path.sh || exit 1; set -e -stage=1 +stage=0 +#TARGET_DIR=${MAIN_ROOT}/dataset/ami +TARGET_DIR=/home/dataset/AMI +data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ +manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ + +save_folder=./save +pretraind_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model +conf_path=conf/ecapa_tdnn.yaml +device=gpu . 
${MAIN_ROOT}/utils/parse_options.sh || exit 1; -if [ ${stage} -le 1 ]; then - # prepare data - bash ./local/data.sh || exit -1 -fi \ No newline at end of file +if [ $stage -le 0 ]; then + # Prepare data + # Download AMI corpus, You need around 10GB of free space to get whole data + # The signals are too large to package in this way, + # so you need to use the chooser to indicate which ones you wish to download + echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data." + echo "Annotations: AMI manual annotations v1.6.2 " + echo "Signals: " + echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py" + echo "2) Select media streams: Just select Headset mix" +fi + +if [ $stage -le 1 ]; then + # Download the pretrained model + wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz + mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder} + rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz + echo "download the pretrained ECAPA-TDNN Model to path: "${pretraind_model_dir} +fi + +if [ $stage -le 2 ]; then + # Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams. + echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path} + bash ./local/process.sh ${data_folder} ${manual_annot_folder} \ + ${save_folder} ${pretraind_model_dir} ${conf_path} ${device} || exit 1 +fi + diff --git a/examples/csmsc/tts0/README.md b/examples/csmsc/tts0/README.md index 0129329ae..01376bd61 100644 --- a/examples/csmsc/tts0/README.md +++ b/examples/csmsc/tts0/README.md @@ -212,7 +212,8 @@ optional arguments: Pretrained Tacotron2 model with no silence in the edge of audios: - [tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip) -The static model can be downloaded here [tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip). +The static model can be downloaded here: +- [tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip) Model | Step | eval/loss | eval/l1_loss | eval/mse_loss | eval/bce_loss| eval/attn_loss diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md index 5f31f7b36..4fbe34cbf 100644 --- a/examples/csmsc/tts2/README.md +++ b/examples/csmsc/tts2/README.md @@ -221,9 +221,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ``` ## Pretrained Model -Pretrained SpeedySpeech model with no silence in the edge of audios[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip). +Pretrained SpeedySpeech model with no silence in the edge of audios: +- [speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip) -The static model can be downloaded here [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip). 
+The static model can be downloaded here: +- [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip) +- [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip) Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss :-------------:| :------------:| :-----: | :-----: | :--------:|:--------: diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index ae8f7af60..bc672f66f 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -232,6 +232,9 @@ The static model can be downloaded here: - [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip) - [fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip) +The ONNX model can be downloaded here: +- [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip) + Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287| diff --git a/examples/csmsc/tts3/local/ort_predict.sh b/examples/csmsc/tts3/local/ort_predict.sh new file mode 100755 index 000000000..3154f6e5a --- /dev/null +++ b/examples/csmsc/tts3/local/ort_predict.sh @@ -0,0 +1,31 @@ +train_output_path=$1 + +stage=0 +stop_stage=0 + +# only support default_fastspeech2 + hifigan/mb_melgan now! 
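
> Note (not part of the patch): the stages below call `ort_predict.py`, which drives the exported models through onnxruntime on CPU. A minimal standalone sketch of the same idea; the model paths, input/output shapes, and the 24 kHz sample rate are illustrative assumptions, not the script's actual interface.

```python
import numpy as np
import onnxruntime as ort
import soundfile as sf

sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 2   # mirrors --cpu_threads=2 above

# Illustrative paths to the exported acoustic model and vocoder.
am = ort.InferenceSession("inference_onnx/fastspeech2_csmsc.onnx",
                          sess_options=sess_opts, providers=["CPUExecutionProvider"])
voc = ort.InferenceSession("inference_onnx/hifigan_csmsc.onnx",
                           sess_options=sess_opts, providers=["CPUExecutionProvider"])

# Assume the acoustic model takes an int64 phone-id sequence and returns a mel spectrogram.
phone_ids = np.array([1, 2, 3, 4, 5], dtype=np.int64)
mel = am.run(None, {am.get_inputs()[0].name: phone_ids})[0]

# Assume the vocoder maps that mel spectrogram to a waveform (CSMSC models use 24 kHz).
wav = voc.run(None, {voc.get_inputs()[0].name: mel})[0]
sf.write("demo.wav", wav.reshape(-1), samplerate=24000)
```
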
+ +# synthesize from metadata +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../ort_predict.py \ + --inference_dir=${train_output_path}/inference_onnx \ + --am=fastspeech2_csmsc \ + --voc=hifigan_csmsc \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/onnx_infer_out \ + --device=cpu \ + --cpu_threads=2 +fi + +# e2e, synthesize from text +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + python3 ${BIN_DIR}/../ort_predict_e2e.py \ + --inference_dir=${train_output_path}/inference_onnx \ + --am=fastspeech2_csmsc \ + --voc=hifigan_csmsc \ + --output_dir=${train_output_path}/onnx_infer_out_e2e \ + --text=${BIN_DIR}/../csmsc_test.txt \ + --phones_dict=dump/phone_id_map.txt \ + --device=cpu \ + --cpu_threads=2 +fi diff --git a/examples/csmsc/tts3/local/paddle2onnx.sh b/examples/csmsc/tts3/local/paddle2onnx.sh new file mode 100755 index 000000000..505f3b663 --- /dev/null +++ b/examples/csmsc/tts3/local/paddle2onnx.sh @@ -0,0 +1,22 @@ +train_output_path=$1 +model_dir=$2 +output_dir=$3 +model=$4 + +enable_dev_version=True + +model_name=${model%_*} +echo model_name: ${model_name} + +if [ ${model_name} = 'mb_melgan' ] ;then + enable_dev_version=False +fi + +mkdir -p ${train_output_path}/${output_dir} + +paddle2onnx \ + --model_dir ${train_output_path}/${model_dir} \ + --model_filename ${model}.pdmodel \ + --params_filename ${model}.pdiparams \ + --save_file ${train_output_path}/${output_dir}/${model}.onnx \ + --enable_dev_version ${enable_dev_version} \ No newline at end of file diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh index e1a149b65..b617d5352 100755 --- a/examples/csmsc/tts3/run.sh +++ b/examples/csmsc/tts3/run.sh @@ -41,3 +41,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 fi +# paddle2onnx, please make sure the static models are in ${train_output_path}/inference first +# we have only tested the following models so far +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx fastspeech2_csmsc + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx mb_melgan_csmsc +fi + +# inference with onnxruntime, use fastspeech2 + hifigan by default +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict.sh ${train_output_path} +fi diff --git a/examples/csmsc/voc1/README.md b/examples/csmsc/voc1/README.md index 5527e8088..2d6de168a 100644 --- a/examples/csmsc/voc1/README.md +++ b/examples/csmsc/voc1/README.md @@ -127,9 +127,11 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -The pretrained model can be downloaded here [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip). 
+The pretrained model can be downloaded here: +- [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip) -The static model can be downloaded here [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip). +The static model can be downloaded here: +- [pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip) Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss| eval/spectral_convergence_loss :-------------:| :------------:| :-----: | :-----: | :--------: diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md index 22104a8f2..12adaf7f4 100644 --- a/examples/csmsc/voc3/README.md +++ b/examples/csmsc/voc3/README.md @@ -152,11 +152,17 @@ TODO: The hyperparameter of `finetune.yaml` is not good enough, a smaller `learning_rate` should be used (more `milestones` should be set). ## Pretrained Models -The pretrained model can be downloaded here [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip). +The pretrained model can be downloaded here: +- [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip) -The finetuned model can be downloaded here [mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip). +The finetuned model can be downloaded here: +- [mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip) -The static model can be downloaded here [mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) +The static model can be downloaded here: +- [mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) + +The ONNX model can be downloaded here: +- [mb_melgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip) Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss|eval/spectral_convergence_loss |eval/sub_log_stft_magnitude_loss|eval/sub_spectral_convergence_loss :-------------:| :------------:| :-----: | :-----: | :--------:| :--------:| :--------: diff --git a/examples/csmsc/voc4/README.md b/examples/csmsc/voc4/README.md index b5c687391..b7add3e57 100644 --- a/examples/csmsc/voc4/README.md +++ b/examples/csmsc/voc4/README.md @@ -112,7 +112,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -The pretrained model can be downloaded here [style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip). +The pretrained model can be downloaded here: +- [style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip) The static model of Style MelGAN is not available now. diff --git a/examples/csmsc/voc5/README.md b/examples/csmsc/voc5/README.md index 21afe6eef..33e676165 100644 --- a/examples/csmsc/voc5/README.md +++ b/examples/csmsc/voc5/README.md @@ -112,9 +112,14 @@ optional arguments: 5. 
`--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -The pretrained model can be downloaded here [hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip). +The pretrained model can be downloaded here: +- [hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip) -The static model can be downloaded here [hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip). +The static model can be downloaded here: +- [hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip) + +The ONNX model can be downloaded here: +- [hifigan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip) Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss :-------------:| :------------:| :-----: | :-----: | :--------: diff --git a/examples/csmsc/voc6/README.md b/examples/csmsc/voc6/README.md index 7763b3551..26d4523d9 100644 --- a/examples/csmsc/voc6/README.md +++ b/examples/csmsc/voc6/README.md @@ -109,9 +109,11 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Models -The pretrained model can be downloaded here [wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip). +The pretrained model can be downloaded here: +- [wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip) -The static model can be downloaded here [wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip). +The static model can be downloaded here: +- [wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip) Model | Step | eval/loss :-------------:|:------------:| :------------: diff --git a/examples/iwslt2012/punc0/README.md b/examples/iwslt2012/punc0/README.md index 74d599a21..6caa9710b 100644 --- a/examples/iwslt2012/punc0/README.md +++ b/examples/iwslt2012/punc0/README.md @@ -21,7 +21,7 @@ The pretrained model can be downloaded here [ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip). 
### Test Result -- Ernie Linear +- Ernie | |COMMA | PERIOD | QUESTION | OVERALL| |:-----:|:-----:|:-----:|:-----:|:-----:| |Precision |0.510955 |0.526462 |0.820755 |0.619391| diff --git a/examples/iwslt2012/punc0/RESULTS.md b/examples/iwslt2012/punc0/RESULTS.md new file mode 100644 index 000000000..2e22713d8 --- /dev/null +++ b/examples/iwslt2012/punc0/RESULTS.md @@ -0,0 +1,9 @@ +# iwslt2012 + +## Ernie + +| |COMMA | PERIOD | QUESTION | OVERALL| +|:-----:|:-----:|:-----:|:-----:|:-----:| +|Precision |0.510955 |0.526462 |0.820755 |0.619391| +|Recall |0.517433 |0.564179 |0.861386 |0.647666| +|F1 |0.514173 |0.544669 |0.840580 |0.633141| diff --git a/examples/librispeech/asr1/README.md b/examples/librispeech/asr1/README.md index eb1a44001..ae252a58b 100644 --- a/examples/librispeech/asr1/README.md +++ b/examples/librispeech/asr1/README.md @@ -151,44 +151,22 @@ avg.sh best exp/conformer/checkpoints 20 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20 ``` ## Pretrained Model -You can get the pretrained transformer or conformer using the scripts below: -```bash -# Conformer: -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz -# Transformer: -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/transformer.model.tar.gz -``` +You can get the pretrained transformer or conformer from [this](../../../docs/source/released_model.md). + using the `tar` scripts to unpack the model and then you can use the script to test the model. For example: ```bash -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz -tar xzvf transformer.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz +tar xzvf asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz source path.sh # If you have process the data and get the manifest file, you can skip the following 2 steps bash local/data.sh --stage -1 --stop_stage -1 bash local/data.sh --stage 2 --stop_stage 2 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_20 ``` -The performance of the released models are shown below: -## Conformer -train: Epoch 70, 4 V100-32G, best avg: 20 - -| Model | Params | Config | Augmentation | Test set | Decode method | Loss | WER | -| --------- | ------- | ------------------- | ------------ | ---------- | ---------------------- | ----------------- | -------- | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | attention | 6.433612394332886 | 0.039771 | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.433612394332886 | 0.040342 | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.433612394332886 | 0.040342 | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-clean | attention_rescoring | 6.433612394332886 | 0.033761 | -## Transformer -train: Epoch 120, 4 V100-32G, 27 Day, best avg: 10 +The performance of the released models are shown in [here](./RESULTS.md). 
-| Model | Params | Config | Augmentation | Test set | Decode method | Loss | WER | -| ----------- | ------- | --------------------- | ------------ | ---------- | ---------------------- | ----------------- | -------- | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.382194232940674 | 0.049661 | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.382194232940674 | 0.049566 | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.382194232940674 | 0.049585 | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.382194232940674 | 0.038135 | ## Stage 4: CTC Alignment If you want to get the alignment between the audio and the text, you can use the ctc alignment. The code of this stage is shown below: ```bash @@ -227,8 +205,8 @@ In some situations, you want to use the trained model to do the inference for th ``` you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below: ```bash -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/conformer.model.tar.gz -tar xzvf conformer.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz +tar xzvf asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz ``` You can download the audio demo: ```bash diff --git a/examples/librispeech/asr2/README.md b/examples/librispeech/asr2/README.md index 7d6fe11df..5bc7185a9 100644 --- a/examples/librispeech/asr2/README.md +++ b/examples/librispeech/asr2/README.md @@ -1,4 +1,4 @@ -# Transformer/Conformer ASR with Librispeech Asr2 +# Transformer/Conformer ASR with Librispeech ASR2 This example contains code used to train a Transformer or [Conformer](http://arxiv.org/abs/2008.03802) model with [Librispeech dataset](http://www.openslr.org/resources/12) and use some functions in kaldi. @@ -213,17 +213,14 @@ avg.sh latest exp/transformer/checkpoints 10 ./local/recog.sh --ckpt_prefix exp/transformer/checkpoints/avg_10 ``` ## Pretrained Model -You can get the pretrained transformer using the scripts below: -```bash -# Transformer: -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/transformer.model.tar.gz -``` +You can get the pretrained models from [this](../../../docs/source/released_model.md). + using the `tar` scripts to unpack the model and then you can use the script to test the model. 
For example: ```bash -wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/transformer.model.tar.gz -tar xzvf transformer.model.tar.gz +wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz +tar xzvf asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz source path.sh # If you have process the data and get the manifest file, you can skip the following 2 steps bash local/data.sh --stage -1 --stop_stage -1 @@ -231,26 +228,7 @@ bash local/data.sh --stage 2 --stop_stage 2 CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer.yaml exp/ctc/checkpoints/avg_10 ``` -The performance of the released models are shown below: -### Transformer -| Model | Params | GPUS | Averaged Model | Config | Augmentation | Loss | -| :---------: | :----: | :--------------------: | :--------------: | :-------------------: | :----------: | :-------------: | -| transformer | 32.52M | 8 Tesla V100-SXM2-32GB | 10-best val_loss | conf/transformer.yaml | spec_aug | 6.3197922706604 | - -#### Attention Rescore -| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | -| ---------- | --------------------- | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ----- | -| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 | -| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | -| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | -| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | - -#### JoinCTC -| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | -| ---------- | ----------------- | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ----- | -| test-clean | join_ctc_only_att | 2620 | 52576 | 96.1 | 2.5 | 1.4 | 0.4 | 4.4 | 34.7 | -| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | -| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 | +The performance of the released models are shown [here](./RESULTS.md). Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus) we using 8gpu, but the model size (aheads4-adim256) small than it. ## Stage 5: CTC Alignment diff --git a/examples/ljspeech/tts1/README.md b/examples/ljspeech/tts1/README.md index 4f7680e84..7f32522ac 100644 --- a/examples/ljspeech/tts1/README.md +++ b/examples/ljspeech/tts1/README.md @@ -171,7 +171,8 @@ optional arguments: 6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained Model can be downloaded here. [transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip) +Pretrained Model can be downloaded here: +- [transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip) TransformerTTS checkpoint contains files listed below. ```text diff --git a/examples/ljspeech/tts3/README.md b/examples/ljspeech/tts3/README.md index f5e919c0f..e028fa05d 100644 --- a/examples/ljspeech/tts3/README.md +++ b/examples/ljspeech/tts3/README.md @@ -214,7 +214,8 @@ optional arguments: 9. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained FastSpeech2 model with no silence in the edge of audios. 
[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) +Pretrained FastSpeech2 model with no silence in the edge of audios: +- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: diff --git a/examples/ljspeech/voc0/README.md b/examples/ljspeech/voc0/README.md index 13a50efb5..41b08d57f 100644 --- a/examples/ljspeech/voc0/README.md +++ b/examples/ljspeech/voc0/README.md @@ -50,4 +50,5 @@ Synthesize waveform. 6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained Model with residual channel equals 128 can be downloaded here. [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip). +Pretrained Model with residual channel equals 128 can be downloaded here: +- [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip) diff --git a/examples/ljspeech/voc1/README.md b/examples/ljspeech/voc1/README.md index 6fcb2a520..4513b2a05 100644 --- a/examples/ljspeech/voc1/README.md +++ b/examples/ljspeech/voc1/README.md @@ -127,7 +127,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained models can be downloaded here. [pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip) +Pretrained models can be downloaded here: +- [pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip) Parallel WaveGAN checkpoint contains files listed below. diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md index 9fbb9f746..9b31e2650 100644 --- a/examples/ljspeech/voc5/README.md +++ b/examples/ljspeech/voc5/README.md @@ -127,7 +127,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -The pretrained model can be downloaded here [hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip). +The pretrained model can be downloaded here: +- [hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip) Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss @@ -143,6 +144,5 @@ hifigan_ljspeech_ckpt_0.2.0 └── snapshot_iter_2500000.pdz # generator parameters of hifigan ``` - ## Acknowledgement We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN. diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md index 157949d1f..f373ca6a3 100644 --- a/examples/vctk/tts3/README.md +++ b/examples/vctk/tts3/README.md @@ -217,7 +217,8 @@ optional arguments: 9. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained FastSpeech2 model with no silence in the edge of audios. 
[fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip) +Pretrained FastSpeech2 model with no silence in the edge of audios: +- [fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip) FastSpeech2 checkpoint contains files listed below. ```text diff --git a/examples/vctk/voc1/README.md b/examples/vctk/voc1/README.md index 4714f28dc..1c3016f88 100644 --- a/examples/vctk/voc1/README.md +++ b/examples/vctk/voc1/README.md @@ -132,7 +132,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -Pretrained models can be downloaded here [pwg_vctk_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip). +Pretrained models can be downloaded here: +- [pwg_vctk_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip) Parallel WaveGAN checkpoint contains files listed below. diff --git a/examples/vctk/voc5/README.md b/examples/vctk/voc5/README.md index b4be341c0..4eb25c02d 100644 --- a/examples/vctk/voc5/README.md +++ b/examples/vctk/voc5/README.md @@ -133,7 +133,8 @@ optional arguments: 5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ## Pretrained Model -The pretrained model can be downloaded here [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip). +The pretrained model can be downloaded here: +- [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip) Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss diff --git a/examples/voxceleb/sv0/RESULT.md b/examples/voxceleb/sv0/RESULT.md index c37bcecef..3a3f67d09 100644 --- a/examples/voxceleb/sv0/RESULT.md +++ b/examples/voxceleb/sv0/RESULT.md @@ -4,4 +4,4 @@ | Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm | | --- | --- | --- | --- | --- | --- | --- | ---- | -| ECAPA-TDNN | 85M | 0.1.1 | conf/ecapa_tdnn.yaml |192 | test | 1.15 | 1.06 | +| ECAPA-TDNN | 85M | 0.2.0 | conf/ecapa_tdnn.yaml |192 | test | 1.02 | 0.95 | diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml index e58dca82d..4715c5a3c 100644 --- a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -1,14 +1,16 @@ ########################################### # Data # ########################################### -# we should explicitly specify the wav path of vox2 audio data converted from m4a -vox2_base_path: augment: True -batch_size: 16 +batch_size: 32 num_workers: 2 -num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 +num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 shuffle: True +skip_prep: False +split_ratio: 0.9 +chunk_duration: 3.0 # seconds random_chunk: True +verification_file: data/vox1/veri_test2.txt ########################################################### # FEATURE EXTRACTION SETTING # @@ -26,7 +28,6 @@ hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 # if we want use another model, please choose another configuration yaml file model: input_size: 80 - # "channels": [512, 512, 512, 512, 1536], channels: [1024, 1024, 1024, 1024, 3072] kernel_sizes: [5, 3, 3, 3, 1] dilations: [1, 2, 
3, 4, 1] @@ -38,8 +39,8 @@ model: ########################################### seed: 1986 # according from speechbrain configuration epochs: 10 -save_interval: 1 -log_interval: 1 +save_interval: 10 +log_interval: 10 learning_rate: 1e-8 diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml new file mode 100644 index 000000000..5ad5ea285 --- /dev/null +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml @@ -0,0 +1,53 @@ +########################################### +# Data # +########################################### +augment: True +batch_size: 16 +num_workers: 2 +num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 +shuffle: True +skip_prep: False +split_ratio: 0.9 +chunk_duration: 3.0 # seconds +random_chunk: True +verification_file: data/vox1/veri_test2.txt + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# currently, we only support fbank +sr: 16000 # sample rate +n_mels: 80 +window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 +hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 + +########################################################### +# MODEL SETTING # +########################################################### +# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml +# if we want use another model, please choose another configuration yaml file +model: + input_size: 80 + channels: [512, 512, 512, 512, 1536] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + attention_channels: 128 + lin_neurons: 192 + +########################################### +# Training # +########################################### +seed: 1986 # according from speechbrain configuration +epochs: 100 +save_interval: 10 +log_interval: 10 +learning_rate: 1e-8 + + +########################################### +# Testing # +########################################### +global_embedding_norm: True +embedding_mean_norm: True +embedding_std_norm: False + diff --git a/examples/voxceleb/sv0/local/data.sh b/examples/voxceleb/sv0/local/data.sh index a3ff1c486..d6010ec66 100755 --- a/examples/voxceleb/sv0/local/data.sh +++ b/examples/voxceleb/sv0/local/data.sh @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -stage=1 +stage=0 stop_stage=100 . 
${MAIN_ROOT}/utils/parse_options.sh || exit -1; @@ -30,29 +30,114 @@ dir=$1 conf_path=$2 mkdir -p ${dir} -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # data prepare for vox1 and vox2, vox2 must be converted from m4a to wav - # we should use the local/convert.sh convert m4a to wav - python3 local/data_prepare.py \ - --data-dir ${dir} \ - --config ${conf_path} -fi - +# Generally the `MAIN_ROOT` refers to the root of PaddleSpeech, +# which is defined in the path.sh +# And we will download the voxceleb data and rirs noise to ${MAIN_ROOT}/dataset TARGET_DIR=${MAIN_ROOT}/dataset mkdir -p ${TARGET_DIR} if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # download data, generate manifests - python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \ - --manifest_prefix="data/vox1/manifest" \ + # download data, generate manifests + # we will generate the manifest.{dev,test} file from ${TARGET_DIR}/voxceleb/vox1/{dev,test} directory + # and generate the meta info and download the trial file + # manifest.dev: 148642 + # manifest.test: 4847 + echo "Start to download vox1 dataset and generate the manifest files " + python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \ + --manifest_prefix="${dir}/vox1/manifest" \ --target_dir="${TARGET_DIR}/voxceleb/vox1/" - if [ $? -ne 0 ]; then - echo "Prepare voxceleb failed. Terminated." - exit 1 - fi + if [ $? -ne 0 ]; then + echo "Prepare voxceleb1 failed. Terminated." + exit 1 + fi + +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # download voxceleb2 data + # we will download the data and unzip the package + # and we will store the m4a file in ${TARGET_DIR}/voxceleb/vox2/{dev,test} + echo "start to download vox2 dataset" + python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \ + --download \ + --target_dir="${TARGET_DIR}/voxceleb/vox2/" + + if [ $? -ne 0 ]; then + echo "Download voxceleb2 dataset failed. Terminated." + exit 1 + fi + +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # convert the m4a to wav + # and we will not delete the original m4a file + echo "start to convert the m4a to wav" + bash local/convert.sh ${TARGET_DIR}/voxceleb/vox2/test/ || exit 1; + + if [ $? -ne 0 ]; then + echo "Convert voxceleb2 dataset from m4a to wav failed. Terminated." + exit 1 + fi + echo "m4a convert to wav operation finished" +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # generate the vox2 manifest file from wav file + # we will generate the ${dir}/vox2/manifest.vox2 + # because we use all the vox2 dataset to train, so collect all the vox2 data in one file + echo "start generate the vox2 manifest files" + python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \ + --generate \ + --manifest_prefix="${dir}/vox2/manifest" \ + --target_dir="${TARGET_DIR}/voxceleb/vox2/" - # for dataset in train dev test; do - # mv data/manifest.${dataset} data/manifest.${dataset}.raw - # done -fi \ No newline at end of file + if [ $? -ne 0 ]; then + echo "Prepare voxceleb2 dataset failed. Terminated." + exit 1 + fi +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # generate the vox csv file + # Currently, our training system use csv file for dataset + echo "convert the json format to csv format to be compatible with training process" + python3 local/make_vox_csv_dataset_from_json.py\ + --train "${dir}/vox1/manifest.dev" "${dir}/vox2/manifest.vox2"\ + --test "${dir}/vox1/manifest.test" \ + --target_dir "${dir}/vox/" \ + --config ${conf_path} + + if [ $? -ne 0 ]; then + echo "Prepare voxceleb failed. Terminated." 
+ exit 1 + fi +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # generate the open rir noise manifest file + echo "generate the open rir noise manifest file" + python3 ${TARGET_DIR}/rir_noise/rir_noise.py\ + --manifest_prefix="${dir}/rir_noise/manifest" \ + --target_dir="${TARGET_DIR}/rir_noise/" + + if [ $? -ne 0 ]; then + echo "Prepare rir_noise failed. Terminated." + exit 1 + fi +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # generate the open rir noise manifest file + echo "generate the open rir noise csv file" + python3 local/make_rirs_noise_csv_dataset_from_json.py \ + --noise_dir="${TARGET_DIR}/rir_noise/" \ + --data_dir="${dir}/rir_noise/" \ + --config ${conf_path} + + if [ $? -ne 0 ]; then + echo "Prepare rir_noise failed. Terminated." + exit 1 + fi +fi diff --git a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py new file mode 100644 index 000000000..b25a9d49a --- /dev/null +++ b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py @@ -0,0 +1,167 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment. +Currently, Speaker Identificaton Training process use csv format. +""" +import argparse +import csv +import os +from typing import List + +import tqdm +from yacs.config import CfgNode + +from paddleaudio import load as load_audio +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.utils.vector_utils import get_chunks + +logger = Log(__name__).getlog() + + +def get_chunks_list(wav_file: str, + split_chunks: bool, + base_path: str, + chunk_duration: float=3.0) -> List[List[str]]: + """Get the single audio file info + + Args: + wav_file (list): the wav audio file and get this audio segment info list + split_chunks (bool): audio split flag + base_path (str): the audio base path + chunk_duration (float): the chunk duration. + if set the split_chunks, we split the audio into multi-chunks segment. + """ + waveform, sr = load_audio(wav_file) + audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0] + audio_duration = waveform.shape[0] / sr + + ret = [] + if split_chunks and audio_duration > chunk_duration: # Split into pieces of self.chunk_duration seconds. + uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration) + + for idx, chunk in enumerate(uniq_chunks_list): + s, e = chunk.split("_")[-2:] # Timestamps of start and end + start_sample = int(float(s) * sr) + end_sample = int(float(e) * sr) + + # currently, all vector csv data format use one representation + # id, duration, wav, start, stop, label + # in rirs noise, all the label name is 'noise' + # the label is string type and we will convert it to integer type in training + ret.append([ + chunk, audio_duration, wav_file, start_sample, end_sample, + "noise" + ]) + else: # Keep whole audio. 
+ ret.append( + [audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"]) + return ret + + +def generate_csv(wav_files, + output_file: str, + base_path: str, + split_chunks: bool=True): + """Prepare the csv file according the wav files + + Args: + wav_files (list): all the audio list to prepare the csv file + output_file (str): the output csv file + config (CfgNode): yaml configuration content + split_chunks (bool): audio split flag + """ + logger.info(f'Generating csv: {output_file}') + header = ["utt_id", "duration", "wav", "start", "stop", "label"] + csv_lines = [] + for item in tqdm.tqdm(wav_files): + csv_lines.extend( + get_chunks_list( + item, base_path=base_path, split_chunks=split_chunks)) + + if not os.path.exists(os.path.dirname(output_file)): + os.makedirs(os.path.dirname(output_file)) + + with open(output_file, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) + csv_writer.writerow(header) + for line in csv_lines: + csv_writer.writerow(line) + + +def prepare_data(args, config): + """Convert the jsonline format to csv format + + Args: + args (argparse.Namespace): scripts args + config (CfgNode): yaml configuration content + """ + # if external config set the skip_prep flat, we will do nothing + if config.skip_prep: + return + + base_path = args.noise_dir + wav_path = os.path.join(base_path, "RIRS_NOISES") + logger.info(f"base path: {base_path}") + logger.info(f"wav path: {wav_path}") + rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list") + rir_files = [] + with open(rir_list, 'r') as f: + for line in f.readlines(): + rir_file = line.strip().split(' ')[-1] + rir_files.append(os.path.join(base_path, rir_file)) + + noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list") + noise_files = [] + with open(noise_list, 'r') as f: + for line in f.readlines(): + noise_file = line.strip().split(' ')[-1] + noise_files.append(os.path.join(base_path, noise_file)) + + csv_path = os.path.join(args.data_dir, 'csv') + logger.info(f"csv path: {csv_path}") + generate_csv( + rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path) + generate_csv( + noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--noise_dir", + default=None, + required=True, + help="The noise dataset dataset directory.") + parser.add_argument( + "--data_dir", + default=None, + required=True, + help="The target directory stores the csv files") + parser.add_argument( + "--config", + default=None, + required=True, + type=str, + help="configuration file") + args = parser.parse_args() + + # parse the yaml config file + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + # prepare the csv file from jsonlines files + prepare_data(args, config) diff --git a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py new file mode 100644 index 000000000..4e64c3067 --- /dev/null +++ b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py @@ -0,0 +1,251 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment. +Currently, Speaker Identificaton Training process use csv format. +""" +import argparse +import csv +import json +import os +import random + +import tqdm +from yacs.config import CfgNode + +from paddleaudio import load as load_audio +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.utils.vector_utils import get_chunks + +logger = Log(__name__).getlog() + + +def prepare_csv(wav_files, output_file, config, split_chunks=True): + """Prepare the csv file according the wav files + + Args: + wav_files (list): all the audio list to prepare the csv file + output_file (str): the output csv file + config (CfgNode): yaml configuration content + split_chunks (bool, optional): audio split flag. Defaults to True. + """ + if not os.path.exists(os.path.dirname(output_file)): + os.makedirs(os.path.dirname(output_file)) + csv_lines = [] + header = ["utt_id", "duration", "wav", "start", "stop", "label"] + # voxceleb meta info for each training utterance segment + # we extract a segment from a utterance to train + # and the segment' period is between start and stop time point in the original wav file + # each field in the meta info means as follows: + # utt_id: the utterance segment name, which is uniq in training dataset + # duration: the total utterance time + # wav: utterance file path, which should be absoulute path + # start: start point in the original wav file sample point range + # stop: stop point in the original wav file sample point range + # label: the utterance segment's label name, + # which is speaker name in speaker verification domain + for item in tqdm.tqdm(wav_files, total=len(wav_files)): + item = json.loads(item.strip()) + audio_id = item['utt'].replace(".wav", + "") # we remove the wav suffix name + audio_duration = item['feat_shape'][0] + wav_file = item['feat'] + label = audio_id.split('-')[ + 0] # speaker name in speaker verification domain + waveform, sr = load_audio(wav_file) + if split_chunks: + uniq_chunks_list = get_chunks(config.chunk_duration, audio_id, + audio_duration) + for chunk in uniq_chunks_list: + s, e = chunk.split("_")[-2:] # Timestamps of start and end + start_sample = int(float(s) * sr) + end_sample = int(float(e) * sr) + # id, duration, wav, start, stop, label + # in vector, the label in speaker id + csv_lines.append([ + chunk, audio_duration, wav_file, start_sample, end_sample, + label + ]) + else: + csv_lines.append([ + audio_id, audio_duration, wav_file, 0, waveform.shape[0], label + ]) + + with open(output_file, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) + csv_writer.writerow(header) + for line in csv_lines: + csv_writer.writerow(line) + + +def get_enroll_test_list(dataset_list, verification_file): + """Get the enroll and test utterance list from all the voxceleb1 test utterance dataset. + Generally, we get the enroll and test utterances from the verfification file. 
+ The verification file format as follows: + target/nontarget enroll-utt test-utt, + we set 0 as nontarget and 1 as target, eg: + 0 a.wav b.wav + 1 a.wav a.wav + + Args: + dataset_list (list): all the dataset to get the test utterances + verification_file (str): voxceleb1 trial file + """ + logger.info(f"verification file: {verification_file}") + enroll_audios = set() + test_audios = set() + with open(verification_file, 'r') as f: + for line in f: + _, enroll_file, test_file = line.strip().split(' ') + enroll_audios.add('-'.join(enroll_file.split('/'))) + test_audios.add('-'.join(test_file.split('/'))) + + enroll_files = [] + test_files = [] + for dataset in dataset_list: + with open(dataset, 'r') as f: + for line in f: + # audio_id may be in enroll and test at the same time + # eg: 1 a.wav a.wav + # the audio a.wav is enroll and test file at the same time + audio_id = json.loads(line.strip())['utt'] + if audio_id in enroll_audios: + enroll_files.append(line) + if audio_id in test_audios: + test_files.append(line) + + enroll_files = sorted(enroll_files) + test_files = sorted(test_files) + + return enroll_files, test_files + + +def get_train_dev_list(dataset_list, target_dir, split_ratio): + """Get the train and dev utterance list from all the training utterance dataset. + Generally, we use the split_ratio as the train dataset ratio, + and the remaining utterance (ratio is 1 - split_ratio) is the dev dataset + + Args: + dataset_list (list): all the dataset to get the all utterances + target_dir (str): the target train and dev directory, + we will create the csv directory to store the {train,dev}.csv file + split_ratio (float): train dataset ratio in all utterance list + """ + logger.info("start to get train and dev utt list") + if not os.path.exists(os.path.join(target_dir, "meta")): + os.makedirs(os.path.join(target_dir, "meta")) + + audio_files = [] + speakers = set() + for dataset in dataset_list: + with open(dataset, 'r') as f: + for line in f: + # the label is speaker name + label_name = json.loads(line.strip())['utt2spk'] + speakers.add(label_name) + audio_files.append(line.strip()) + speakers = sorted(speakers) + logger.info(f"we get {len(speakers)} speakers from all the train dataset") + + with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f: + for label_id, label_name in enumerate(speakers): + f.write(f'{label_name} {label_id}\n') + logger.info( + f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}' + ) + + # the split_ratio is for train dataset + # the remaining is for dev dataset + split_idx = int(split_ratio * len(audio_files)) + audio_files = sorted(audio_files) + random.shuffle(audio_files) + train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:] + logger.info( + f"we get train utterances: {len(train_files)}, dev utterance: {len(dev_files)}" + ) + return train_files, dev_files + + +def prepare_data(args, config): + """Convert the jsonline format to csv format + + Args: + args (argparse.Namespace): scripts args + config (CfgNode): yaml configuration content + """ + # stage0: set the random seed + random.seed(config.seed) + + # if external config set the skip_prep flat, we will do nothing + if config.skip_prep: + return + + # stage 1: prepare the enroll and test csv file + # And we generate the speaker to label file label2id.txt + logger.info("start to prepare the data csv file") + enroll_files, test_files = get_enroll_test_list( + [args.test], verification_file=config.verification_file) + prepare_csv( + enroll_files, 
+ os.path.join(args.target_dir, "csv", "enroll.csv"), + config, + split_chunks=False) + prepare_csv( + test_files, + os.path.join(args.target_dir, "csv", "test.csv"), + config, + split_chunks=False) + + # stage 2: prepare the train and dev csv file + # we get the train dataset ratio as config.split_ratio + # and the remaining is dev dataset + logger.info("start to prepare the data csv file") + train_files, dev_files = get_train_dev_list( + args.train, target_dir=args.target_dir, split_ratio=config.split_ratio) + prepare_csv(train_files, + os.path.join(args.target_dir, "csv", "train.csv"), config) + prepare_csv(dev_files, + os.path.join(args.target_dir, "csv", "dev.csv"), config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--train", + required=True, + nargs='+', + help="The jsonline files list for train.") + parser.add_argument( + "--test", required=True, help="The jsonline file for test") + parser.add_argument( + "--target_dir", + default=None, + required=True, + help="The target directory stores the csv files and meta file.") + parser.add_argument( + "--config", + default=None, + required=True, + type=str, + help="configuration file") + args = parser.parse_args() + + # parse the yaml config file + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + # prepare the csv file from jsonlines files + prepare_data(args, config) diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh index bbc9e3dbb..e1dccf2ae 100755 --- a/examples/voxceleb/sv0/run.sh +++ b/examples/voxceleb/sv0/run.sh @@ -18,24 +18,22 @@ set -e ####################################################################### # stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv -# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md with the script local/convert.sh +# voxceleb2 data is m4a format, so we need convert the m4a to wav yourselves with the script local/convert.sh # stage 1: train the speaker identification model # stage 2: test speaker identification -# stage 3: extract the training embeding to train the LDA and PLDA +# stage 3: (todo)extract the training embeding to train the LDA and PLDA ###################################################################### -# we can set the variable PPAUDIO_HOME to specifiy the root directory of the downloaded vox1 and vox2 dataset -# default the dataset will be stored in the ~/.paddleaudio/ # the vox2 dataset is stored in m4a format, we need to convert the audio from m4a to wav yourself -# and put all of them to ${PPAUDIO_HOME}/datasets/vox2 -# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav -# export PPAUDIO_HOME= +# and put all of them to ${MAIN_ROOT}/datasets/vox2 +# we will find the wav from ${MAIN_ROOT}/datasets/vox1/{dev,test}/wav and ${MAIN_ROOT}/datasets/vox2/wav + stage=0 stop_stage=50 # data directory # if we set the variable ${dir}, we will store the wav info to this directory -# otherwise, we will store the wav info to vox1 and vox2 directory respectively +# otherwise, we will store the wav info to data/vox1 and data/vox2 directory respectively # vox2 wav path, we must convert the m4a format to wav format dir=data/ # data info directory @@ -64,6 +62,6 @@ if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then fi # if [ $stage -le 3 ]; then -# # stage 2: extract the training embeding to train the LDA and PLDA +# 
# stage 3: extract the training embeding to train the LDA and PLDA # # todo: extract the training embedding # fi diff --git a/paddleaudio/paddleaudio/datasets/voxceleb.py b/paddleaudio/paddleaudio/datasets/voxceleb.py index 3f72b5f2e..07f44e0c1 100644 --- a/paddleaudio/paddleaudio/datasets/voxceleb.py +++ b/paddleaudio/paddleaudio/datasets/voxceleb.py @@ -261,7 +261,7 @@ class VoxCeleb(Dataset): output_file: str, split_chunks: bool=True): print(f'Generating csv: {output_file}') - header = ["ID", "duration", "wav", "start", "stop", "spk_id"] + header = ["id", "duration", "wav", "start", "stop", "spk_id"] # Note: this may occurs c++ execption, but the program will execute fine # so we can ignore the execption with Pool(cpu_count()) as p: diff --git a/paddleaudio/paddleaudio/utils/numeric.py b/paddleaudio/paddleaudio/utils/numeric.py new file mode 100644 index 000000000..126cada50 --- /dev/null +++ b/paddleaudio/paddleaudio/utils/numeric.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + + +def pcm16to32(audio: np.ndarray) -> np.ndarray: + """pcm int16 to float32 + + Args: + audio (np.ndarray): Waveform with dtype of int16. + + Returns: + np.ndarray: Waveform with dtype of float32. 
+ """ + if audio.dtype == np.int16: + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 1fb4be434..b12b9f6fc 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -80,9 +80,9 @@ pretrained_models = { }, "deepspeech2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -426,6 +426,11 @@ class ASRExecutor(BaseExecutor): try: audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) + audio_duration = audio.shape[0] / audio_sample_rate + max_duration = 50.0 + if audio_duration >= max_duration: + logger.error("Please input audio file less then 50 seconds.\n") + return except Exception as e: logger.exception(e) logger.error( diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 175a9723e..68e832ac7 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -15,6 +15,7 @@ import argparse import os import sys from collections import OrderedDict +from typing import Dict from typing import List from typing import Optional from typing import Union @@ -42,9 +43,9 @@ pretrained_models = { # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" "ecapatdnn_voxceleb12-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz', + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', 'md5': - 'a1c0dba7d4de997187786ff517d5b4ec', + 'cc33023c54ab346cd318408f43fcaf95', 'cfg_path': 'conf/model.yaml', # the yaml config path 'ckpt_path': @@ -79,7 +80,7 @@ class VectorExecutor(BaseExecutor): "--task", type=str, default="spk", - choices=["spk"], + choices=["spk", "score"], help="task type in vector domain") self.parser.add_argument( "--input", @@ -147,13 +148,40 @@ class VectorExecutor(BaseExecutor): logger.info(f"task source: {task_source}") # stage 3: process the audio one by one + # we do action according the task type task_result = OrderedDict() has_exceptions = False for id_, input_ in task_source.items(): try: - res = self(input_, model, sample_rate, config, ckpt_path, - device) - task_result[id_] = res + # extract the speaker audio embedding + if parser_args.task == "spk": + logger.info("do vector spk task") + res = self(input_, model, sample_rate, config, ckpt_path, + device) + task_result[id_] = res + elif parser_args.task == "score": + logger.info("do vector score task") + logger.info(f"input content {input_}") + if len(input_.split()) != 2: + logger.error( + f"vector score task input {input_} wav num is not two," + "that is {len(input_.split())}") + sys.exit(-1) + + # get the enroll and test embedding + enroll_audio, test_audio = input_.split() + logger.info( + f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}" + ) + enroll_embedding = self(enroll_audio, model, sample_rate, + config, ckpt_path, device) + test_embedding = self(test_audio, model, sample_rate, + config, ckpt_path, device) + + # get the score + res = self.get_embeddings_score(enroll_embedding, + 
test_embedding) + task_result[id_] = res except Exception as e: has_exceptions = True task_result[id_] = f'{e.__class__.__name__}: {e}' @@ -172,6 +200,49 @@ class VectorExecutor(BaseExecutor): else: return True + def _get_job_contents( + self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: + """ + Read a job input file and return its contents in a dictionary. + Refactor from the Executor._get_job_contents + + Args: + job_input (os.PathLike): The job input file. + + Returns: + Dict[str, str]: Contents of job input. + """ + job_contents = OrderedDict() + with open(job_input) as f: + for line in f: + line = line.strip() + if not line: + continue + k = line.split(' ')[0] + v = ' '.join(line.split(' ')[1:]) + job_contents[k] = v + return job_contents + + def get_embeddings_score(self, enroll_embedding, test_embedding): + """get the enroll embedding and test embedding score + + Args: + enroll_embedding (numpy.array): shape: (emb_size), enroll audio embedding + test_embedding (numpy.array): shape: (emb_size), test audio embedding + + Returns: + score: the score between enroll embedding and test embedding + """ + if not hasattr(self, "score_func"): + self.score_func = paddle.nn.CosineSimilarity(axis=0) + logger.info("create the cosine score function ") + + score = self.score_func( + paddle.to_tensor(enroll_embedding), + paddle.to_tensor(test_embedding)) + + return score.item() + @stats_wrapper def __call__(self, audio_file: os.PathLike, diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index f6a7f4295..474a8b79f 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -23,8 +23,9 @@ from ..util import cli_server_register from ..util import stats_wrapper from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import init_engine_pool -from paddlespeech.server.restful.api import setup_router +from paddlespeech.server.restful.api import setup_router as setup_http_router from paddlespeech.server.utils.config import get_config +from paddlespeech.server.ws.api import setup_router as setup_ws_router __all__ = ['ServerExecutor', 'ServerStatsExecutor'] @@ -63,7 +64,12 @@ class ServerExecutor(BaseExecutor): """ # init api api_list = list(engine.split("_")[0] for engine in config.engine_list) - api_router = setup_router(api_list) + if config.protocol == "websocket": + api_router = setup_ws_router(api_list) + elif config.protocol == "http": + api_router = setup_http_router(api_list) + else: + raise Exception("unsupported protocol") app.include_router(api_router) if not init_engine_pool(config): diff --git a/paddlespeech/server/conf/tts_online_application.yaml b/paddlespeech/server/conf/tts_online_application.yaml new file mode 100644 index 000000000..a80b3ecec --- /dev/null +++ b/paddlespeech/server/conf/tts_online_application.yaml @@ -0,0 +1,46 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 127.0.0.1 +port: 8092 + +# The task format in the engin_list is: _ +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). 
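+# To try this configuration, start the server with this file; the protocol
+# field below decides which routes are mounted, e.g. (command name assumed
+# from the existing serving CLI):
+#   paddlespeech_server start --config_file ./conf/tts_online_application.yaml
+# With protocol 'http' the streaming route is POST /paddlespeech/streaming/tts;
+# with 'websocket' it is the /ws/tts endpoint added further down in this patch.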
+protocol: 'http' +engine_list: ['tts_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### TTS ######################################### +################### speech task: tts; engine_type: online ####################### +tts_online: + # am (acoustic model) choices=['fastspeech2_csmsc'] + am: 'fastspeech2_csmsc' + am_config: + am_ckpt: + am_stat: + phones_dict: + tones_dict: + speaker_dict: + spk_id: 0 + + # voc (vocoder) choices=['mb_melgan_csmsc'] + voc: 'mb_melgan_csmsc' + voc_config: + voc_ckpt: + voc_stat: + + # others + lang: 'zh' + device: # set 'gpu:id' or 'cpu' + am_block: 42 + am_pad: 12 + voc_block: 14 + voc_pad: 14 + diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 389175a0a..1f356a3c6 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor __all__ = ['ASREngine'] @@ -36,7 +37,7 @@ pretrained_models = { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor): else: raise Exception("invalid model name") - def _pcm16to32(self, audio): - """pcm int16 to float32 - - Args: - audio(numpy.array): numpy.int16 - - Returns: - audio(numpy.array): numpy.float32 - """ - if audio.dtype == np.int16: - audio = audio.astype("float32") - bits = np.iinfo(np.int16).bits - audio = audio / (2**(bits - 1)) - return audio - def extract_feat(self, samples, sample_rate): """extract feat @@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor): x_chunk_lens (numpy.array): shape[B] """ # pcm16 -> pcm 32 - samples = self._pcm16to32(samples) + samples = pcm2float(samples) # read audio speech_segment = SpeechSegment.from_pcm( diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 2a39fb79b..e147a29a6 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -34,6 +34,9 @@ class EngineFactory(object): elif engine_name == 'tts' and engine_type == 'python': from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine return TTSEngine() + elif engine_name == 'tts' and engine_type == 'online': + from paddlespeech.server.engine.tts.online.tts_engine import TTSEngine + return TTSEngine() elif engine_name == 'cls' and engine_type == 'inference': from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine return CLSEngine() diff --git a/paddlespeech/server/engine/tts/online/__init__.py b/paddlespeech/server/engine/tts/online/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/server/engine/tts/online/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py new file mode 100644 index 000000000..25a8bc76f --- /dev/null +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import time + +import numpy as np +import paddle + +from paddlespeech.cli.log import logger +from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.audio_process import float2pcm +from paddlespeech.server.utils.util import get_chunks + +__all__ = ['TTSEngine'] + + +class TTSServerExecutor(TTSExecutor): + def __init__(self): + super().__init__() + pass + + @paddle.no_grad() + def infer( + self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0, + am_block: int=42, + am_pad: int=12, + voc_block: int=14, + voc_pad: int=14, ): + """ + Model inference and result stored in self.output. 
+ """ + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + get_tone_ids = False + merge_sentences = False + frontend_st = time.time() + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.frontend.get_input_ids( + text, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + self.frontend_time = time.time() - frontend_st + + for i in range(len(phone_ids)): + am_st = time.time() + part_phone_ids = phone_ids[i] + # am + if am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + mel = self.am_inference(part_phone_ids, part_tone_ids) + # fastspeech2 + else: + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + mel = self.am_inference( + part_phone_ids, spk_id=paddle.to_tensor(spk_id)) + else: + mel = self.am_inference(part_phone_ids) + am_et = time.time() + + # voc streaming + voc_upsample = self.voc_config.n_shift + mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + chunk_num = len(mel_chunks) + voc_st = time.time() + for i, mel_chunk in enumerate(mel_chunks): + sub_wav = self.voc_inference(mel_chunk) + front_pad = min(i * voc_block, voc_pad) + + if i == 0: + sub_wav = sub_wav[:voc_block * voc_upsample] + elif i == chunk_num - 1: + sub_wav = sub_wav[front_pad * voc_upsample:] + else: + sub_wav = sub_wav[front_pad * voc_upsample:( + front_pad + voc_block) * voc_upsample] + + yield sub_wav + + +class TTSEngine(BaseEngine): + """TTS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self, name=None): + """Initialize TTS server engine + """ + super(TTSEngine, self).__init__() + + def init(self, config: dict) -> bool: + self.executor = TTSServerExecutor() + self.config = config + assert "fastspeech2_csmsc" in config.am and ( + config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc" + ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + try: + if self.config.device: + self.device = self.config.device + else: + self.device = paddle.get_device() + paddle.set_device(self.device) + except Exception as e: + logger.error( + "Set device failed, please check if device is already used and the parameter 'device' in the yaml file" + ) + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + return False + + try: + self.executor._init_from_path( + am=self.config.am, + am_config=self.config.am_config, + am_ckpt=self.config.am_ckpt, + am_stat=self.config.am_stat, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_config=self.config.voc_config, + voc_ckpt=self.config.voc_ckpt, + voc_stat=self.config.voc_stat, + lang=self.config.lang) + except Exception as e: + logger.error("Failed to get model related files.") + logger.error("Initialize TTS server engine Failed on device: %s." % + (self.device)) + return False + + self.am_block = self.config.am_block + self.am_pad = self.config.am_pad + self.voc_block = self.config.voc_block + self.voc_pad = self.config.voc_pad + + logger.info("Initialize TTS server engine successfully on device: %s." 
% + (self.device)) + return True + + def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): + # Convert byte to text + if text_bese64: + text_bytes = base64.b64decode(text_bese64) # base64 to bytes + text = text_bytes.decode('utf-8') # bytes to text + + return text + + def run(self, + sentence: str, + spk_id: int=0, + speed: float=1.0, + volume: float=1.0, + sample_rate: int=0, + save_path: str=None): + """ run include inference and postprocess. + + Args: + sentence (str): text to be synthesized + spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0. + speed (float, optional): speed. Defaults to 1.0. + volume (float, optional): volume. Defaults to 1.0. + sample_rate (int, optional): target sample rate for synthesized audio, + 0 means the same as the model sampling rate. Defaults to 0. + save_path (str, optional): The save path of the synthesized audio. + None means do not save audio. Defaults to None. + + Returns: + wav_base64: The base64 format of the synthesized audio. + """ + + lang = self.config.lang + wav_list = [] + + for wav in self.executor.infer( + text=sentence, + lang=lang, + am=self.config.am, + spk_id=spk_id, + am_block=self.am_block, + am_pad=self.am_pad, + voc_block=self.voc_block, + voc_pad=self.voc_pad): + # wav type: float32, convert to pcm (base64) + wav = float2pcm(wav) # float32 to int16 + wav_bytes = wav.tobytes() # to bytes + wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64 + wav_list.append(wav) + + yield wav_base64 + + wav_all = np.concatenate(wav_list, axis=0) + logger.info("The durations of audio is: {} s".format( + len(wav_all) / self.executor.am_config.fs)) diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py index 4e9bbe23e..d1268428a 100644 --- a/paddlespeech/server/restful/tts_api.py +++ b/paddlespeech/server/restful/tts_api.py @@ -15,6 +15,7 @@ import traceback from typing import Union from fastapi import APIRouter +from fastapi.responses import StreamingResponse from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import get_engine_pool @@ -125,3 +126,14 @@ def tts(request_body: TTSRequest): traceback.print_exc() return response + + +@router.post("/paddlespeech/streaming/tts") +async def stream_tts(request_body: TTSRequest): + text = request_body.text + + engine_pool = get_engine_pool() + tts_engine = engine_pool['tts'] + logger.info("Get tts engine successfully.") + + return StreamingResponse(tts_engine.run(sentence=text)) diff --git a/paddlespeech/server/tests/tts/test_client.py b/paddlespeech/server/tests/tts/offline/http_client.py similarity index 90% rename from paddlespeech/server/tests/tts/test_client.py rename to paddlespeech/server/tests/tts/offline/http_client.py index e42c9bcfa..1bdee4c18 100644 --- a/paddlespeech/server/tests/tts/test_client.py +++ b/paddlespeech/server/tests/tts/offline/http_client.py @@ -33,7 +33,8 @@ def tts_client(args): text: A sentence to be synthesized outfile: Synthetic audio file """ - url = 'http://127.0.0.1:8090/paddlespeech/tts' + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/tts" request = { "text": args.text, "spk_id": args.spk_id, @@ -72,7 +73,7 @@ if __name__ == "__main__": parser.add_argument( '--text', type=str, - default="你好,欢迎使用语音合成服务", + default="您好,欢迎使用语音合成服务。", help='A sentence to be synthesized') parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') parser.add_argument('--speed', type=float, default=1.0, help='Audio 
speed') @@ -88,6 +89,9 @@ if __name__ == "__main__": type=str, default="./out.wav", help='Synthesized audio file') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8090) args = parser.parse_args() st = time.time() diff --git a/paddlespeech/server/tests/tts/online/http_client.py b/paddlespeech/server/tests/tts/online/http_client.py new file mode 100644 index 000000000..cbc1f5c02 --- /dev/null +++ b/paddlespeech/server/tests/tts/online/http_client.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import base64 +import json +import os +import time + +import requests + +from paddlespeech.server.utils.audio_process import pcm2wav + + +def save_audio(buffer, audio_path) -> bool: + if args.save_path.endswith("pcm"): + with open(args.save_path, "wb") as f: + f.write(buffer) + elif args.save_path.endswith("wav"): + with open("./tmp.pcm", "wb") as f: + f.write(buffer) + pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000) + os.system("rm ./tmp.pcm") + else: + print("Only supports saved audio format is pcm or wav") + return False + + return True + + +def test(args): + params = { + "text": args.text, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": '' + } + + buffer = b'' + flag = 1 + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/streaming/tts" + st = time.time() + html = requests.post(url, json.dumps(params), stream=True) + for chunk in html.iter_content(chunk_size=1024): + chunk = base64.b64decode(chunk) # bytes + if flag: + first_response = time.time() - st + print(f"首包响应:{first_response} s") + flag = 0 + buffer += chunk + + final_response = time.time() - st + duration = len(buffer) / 2.0 / 24000 + + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + + if args.save_path is not None: + if save_audio(buffer, args.save_path): + print("音频保存至:", args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--text', + type=str, + default="您好,欢迎使用语音合成服务。", + help='A sentence to be synthesized') + parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') + parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') + parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + parser.add_argument( + "--save_path", type=str, help="save audio path", default=None) + + args = parser.parse_args() + test(args) diff --git 
a/paddlespeech/server/tests/tts/online/http_client_playaudio.py b/paddlespeech/server/tests/tts/online/http_client_playaudio.py new file mode 100644 index 000000000..1e7e8064e --- /dev/null +++ b/paddlespeech/server/tests/tts/online/http_client_playaudio.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import base64 +import json +import threading +import time + +import pyaudio +import requests + +mutex = threading.Lock() +buffer = b'' +p = pyaudio.PyAudio() +stream = p.open( + format=p.get_format_from_width(2), channels=1, rate=24000, output=True) +max_fail = 50 + + +def play_audio(): + global stream + global buffer + global max_fail + while True: + if not buffer: + max_fail -= 1 + time.sleep(0.05) + if max_fail < 0: + break + mutex.acquire() + stream.write(buffer) + buffer = b'' + mutex.release() + + +def test(args): + global mutex + global buffer + params = { + "text": args.text, + "spk_id": args.spk_id, + "speed": args.speed, + "volume": args.volume, + "sample_rate": args.sample_rate, + "save_path": '' + } + + all_bytes = 0.0 + t = threading.Thread(target=play_audio) + flag = 1 + url = "http://" + str(args.server) + ":" + str( + args.port) + "/paddlespeech/streaming/tts" + st = time.time() + html = requests.post(url, json.dumps(params), stream=True) + for chunk in html.iter_content(chunk_size=1024): + mutex.acquire() + chunk = base64.b64decode(chunk) # bytes + buffer += chunk + mutex.release() + if flag: + first_response = time.time() - st + print(f"首包响应:{first_response} s") + flag = 0 + t.start() + all_bytes += len(chunk) + + final_response = time.time() - st + duration = all_bytes / 2 / 24000 + + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + + t.join() + stream.stop_stream() + stream.close() + p.terminate() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--text', + type=str, + default="您好,欢迎使用语音合成服务。", + help='A sentence to be synthesized') + parser.add_argument('--spk_id', type=int, default=0, help='Speaker id') + parser.add_argument('--speed', type=float, default=1.0, help='Audio speed') + parser.add_argument( + '--volume', type=float, default=1.0, help='Audio volume') + parser.add_argument( + '--sample_rate', + type=int, + default=0, + help='Sampling rate, the default is the same as the model') + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + + args = parser.parse_args() + test(args) diff --git a/paddlespeech/server/tests/tts/online/ws_client.py b/paddlespeech/server/tests/tts/online/ws_client.py new file mode 100644 index 000000000..eef010cf2 --- /dev/null +++ b/paddlespeech/server/tests/tts/online/ws_client.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _thread as thread +import argparse +import base64 +import json +import ssl +import time + +import websocket + +flag = 1 +st = 0.0 +all_bytes = b'' + + +class WsParam(object): + # 初始化 + def __init__(self, text, server="127.0.0.1", port=8090): + self.server = server + self.port = port + self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts" + self.text = text + + # 生成url + def create_url(self): + return self.url + + +def on_message(ws, message): + global flag + global st + global all_bytes + + try: + message = json.loads(message) + audio = message["audio"] + audio = base64.b64decode(audio) # bytes + status = message["status"] + all_bytes += audio + + if status == 0: + print("create successfully.") + elif status == 1: + if flag: + print(f"首包响应:{time.time() - st} s") + flag = 0 + elif status == 2: + final_response = time.time() - st + duration = len(all_bytes) / 2.0 / 24000 + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + with open("./out.pcm", "wb") as f: + f.write(all_bytes) + print("ws is closed") + ws.close() + else: + print("infer error") + + except Exception as e: + print("receive msg,but parse exception:", e) + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + def run(*args): + global st + text_base64 = str( + base64.b64encode((wsParam.text).encode('utf-8')), "UTF8") + d = {"text": text_base64} + d = json.dumps(d) + print("Start sending text data") + st = time.time() + ws.send(d) + + thread.start_new_thread(run, ()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="A sentence to be synthesized", + default="您好,欢迎使用语音合成服务。") + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + args = parser.parse_args() + + print("***************************************") + print("Server ip: ", args.server) + print("Server port: ", args.port) + print("Sentence to be synthesized: ", args.text) + print("***************************************") + + wsParam = WsParam(text=args.text, server=args.server, port=args.port) + + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp( + wsUrl, on_message=on_message, on_error=on_error, on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) diff --git a/paddlespeech/server/tests/tts/online/ws_client_playaudio.py b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py new file mode 100644 index 000000000..cdeb362df --- /dev/null +++ b/paddlespeech/server/tests/tts/online/ws_client_playaudio.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _thread as thread +import argparse +import base64 +import json +import ssl +import threading +import time + +import pyaudio +import websocket + +mutex = threading.Lock() +buffer = b'' +p = pyaudio.PyAudio() +stream = p.open( + format=p.get_format_from_width(2), channels=1, rate=24000, output=True) +flag = 1 +st = 0.0 +all_bytes = 0.0 + + +class WsParam(object): + # 初始化 + def __init__(self, text, server="127.0.0.1", port=8090): + self.server = server + self.port = port + self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts" + self.text = text + + # 生成url + def create_url(self): + return self.url + + +def play_audio(): + global stream + global buffer + while True: + time.sleep(0.05) + if not buffer: # buffer 为空 + break + mutex.acquire() + stream.write(buffer) + buffer = b'' + mutex.release() + + +t = threading.Thread(target=play_audio) + + +def on_message(ws, message): + global flag + global t + global buffer + global st + global all_bytes + + try: + message = json.loads(message) + audio = message["audio"] + audio = base64.b64decode(audio) # bytes + status = message["status"] + all_bytes += len(audio) + + if status == 0: + print("create successfully.") + elif status == 1: + mutex.acquire() + buffer += audio + mutex.release() + if flag: + print(f"首包响应:{time.time() - st} s") + flag = 0 + print("Start playing audio") + t.start() + elif status == 2: + final_response = time.time() - st + duration = all_bytes / 2 / 24000 + print(f"尾包响应:{final_response} s") + print(f"音频时长:{duration} s") + print(f"RTF: {final_response / duration}") + print("ws is closed") + ws.close() + else: + print("infer error") + + except Exception as e: + print("receive msg,but parse exception:", e) + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + def run(*args): + global st + text_base64 = str( + base64.b64encode((wsParam.text).encode('utf-8')), "UTF8") + d = {"text": text_base64} + d = json.dumps(d) + print("Start sending text data") + st = time.time() + ws.send(d) + + thread.start_new_thread(run, ()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="A sentence to be synthesized", + default="您好,欢迎使用语音合成服务。") + parser.add_argument( + "--server", type=str, help="server ip", default="127.0.0.1") + parser.add_argument("--port", type=int, help="server port", default=8092) + args = parser.parse_args() + + print("***************************************") + print("Server ip: ", args.server) + print("Server port: ", args.port) + print("Sentence to be synthesized: ", args.text) + print("***************************************") + + wsParam = WsParam(text=args.text, server=args.server, port=args.port) + + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp( + wsUrl, on_message=on_message, on_error=on_error, 
on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + + t.join() + print("End of playing audio") + stream.stop_stream() + stream.close() + p.terminate() diff --git a/paddlespeech/server/utils/audio_process.py b/paddlespeech/server/utils/audio_process.py index 3cbb495a6..e85b9a27e 100644 --- a/paddlespeech/server/utils/audio_process.py +++ b/paddlespeech/server/utils/audio_process.py @@ -103,3 +103,40 @@ def change_speed(sample_raw, speed_rate, sample_rate): sample_rate_in=sample_rate).squeeze(-1).astype(np.float32).copy() return sample_speed + + +def float2pcm(sig, dtype='int16'): + """Convert floating point signal with a range from -1 to 1 to PCM. + + Args: + sig (array): Input array, must have floating point type. + dtype (str, optional): Desired (integer) data type. Defaults to 'int16'. + + Returns: + numpy.ndarray: Integer data, scaled and clipped to the range of the given + """ + sig = np.asarray(sig) + if sig.dtype.kind != 'f': + raise TypeError("'sig' must be a float array") + dtype = np.dtype(dtype) + if dtype.kind not in 'iu': + raise TypeError("'dtype' must be an integer type") + + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype) + + +def pcm2float(data): + """pcm int16 to float32 + Args: + audio(numpy.array): numpy.int16 + Returns: + audio(numpy.array): numpy.float32 + """ + if data.dtype == np.int16: + data = data.astype("float32") + bits = np.iinfo(np.int16).bits + data = data / (2**(bits - 1)) + return data diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index e9104fa2d..0fe70849d 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the import base64 +import math def wav2base64(wav_file: str): @@ -31,3 +32,42 @@ def self_check(): """ self check resource """ return True + + +def denorm(data, mean, std): + """stream am model need to denorm + """ + return data * std + mean + + +def get_chunks(data, block_size, pad_size, step): + """Divide data into multiple chunks + + Args: + data (tensor): data + block_size (int): [description] + pad_size (int): [description] + step (str): set "am" or "voc", generate chunk for step am or vocoder(voc) + + Returns: + list: chunks list + """ + if step == "am": + data_len = data.shape[1] + elif step == "voc": + data_len = data.shape[0] + else: + print("Please set correct type to get chunks, am or voc") + + chunks = [] + n = math.ceil(data_len / block_size) + for i in range(n): + start = max(0, i * block_size - pad_size) + end = min((i + 1) * block_size + pad_size, data_len) + if step == "am": + chunks.append(data[:, start:end, :]) + elif step == "voc": + chunks.append(data[start:end, :]) + else: + print("Please set correct type to get chunks, am or voc") + return chunks diff --git a/paddlespeech/server/ws/api.py b/paddlespeech/server/ws/api.py index 10664d114..313fd16f5 100644 --- a/paddlespeech/server/ws/api.py +++ b/paddlespeech/server/ws/api.py @@ -16,6 +16,7 @@ from typing import List from fastapi import APIRouter from paddlespeech.server.ws.asr_socket import router as asr_router +from paddlespeech.server.ws.tts_socket import router as tts_router _router = APIRouter() @@ -31,7 +32,7 @@ def setup_router(api_list: List): if api_name == 'asr': _router.include_router(asr_router) elif api_name == 'tts': - pass + _router.include_router(tts_router) else: pass diff --git a/paddlespeech/server/ws/tts_socket.py b/paddlespeech/server/ws/tts_socket.py new file mode 100644 index 000000000..11458b3cf --- /dev/null +++ b/paddlespeech/server/ws/tts_socket.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
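+
+# This module is the websocket counterpart of the streaming TTS HTTP route.
+# The client sends a JSON message {"text": <base64-encoded utf-8 text>}; the
+# handler below streams one JSON reply per synthesized chunk,
+# {"status": 1, "audio": <base64 pcm16 bytes>}, and finishes with
+# {"status": 2, "audio": ""} so clients such as the bundled ws_client.py know
+# the stream has ended.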
+import json + +from fastapi import APIRouter +from fastapi import WebSocket +from fastapi import WebSocketDisconnect +from starlette.websockets import WebSocketState as WebSocketState + +from paddlespeech.cli.log import logger +from paddlespeech.server.engine.engine_pool import get_engine_pool + +router = APIRouter() + + +@router.websocket('/ws/tts') +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + try: + # careful here, changed the source code from starlette.websockets + assert websocket.application_state == WebSocketState.CONNECTED + message = await websocket.receive() + websocket._raise_on_disconnect(message) + + # get engine + engine_pool = get_engine_pool() + tts_engine = engine_pool['tts'] + + # 获取 message 并转文本 + message = json.loads(message["text"]) + text_bese64 = message["text"] + sentence = tts_engine.preprocess(text_bese64=text_bese64) + + # run + wav_generator = tts_engine.run(sentence) + + while True: + try: + tts_results = next(wav_generator) + resp = {"status": 1, "audio": tts_results} + await websocket.send_json(resp) + logger.info("streaming audio...") + except StopIteration as e: + resp = {"status": 2, "audio": ''} + await websocket.send_json(resp) + logger.info("Complete the transmission of audio streams") + break + + except WebSocketDisconnect: + pass diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py index 5bda75451..db1842b2e 100644 --- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py +++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py @@ -86,6 +86,9 @@ def process_sentence(config: Dict[str, Any], logmel = mel_extractor.get_log_mel_fbank(wav) # change duration according to mel_length compare_duration_and_mel_length(sentences, utt_id, logmel) + # utt_id may be popped in compare_duration_and_mel_length + if utt_id not in sentences: + return None phones = sentences[utt_id][0] durations = sentences[utt_id][1] num_frames = logmel.shape[0] diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 1188ddfb1..62602a01f 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -104,7 +104,7 @@ def get_voc_output(args, voc_predictor, input): def parse_args(): parser = argparse.ArgumentParser( - description="Paddle Infernce with speedyspeech & parallel wavegan.") + description="Paddle Infernce with acoustic model & vocoder.") # acoustic model parser.add_argument( '--am', diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py new file mode 100644 index 000000000..e8d4d61c3 --- /dev/null +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
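+
+# Overview of this script: it builds one onnxruntime session for the exported
+# acoustic model and one for the vocoder, warms both up with dummy inputs,
+# then walks the jsonlines test metadata, timing AM + vocoder inference,
+# writing the synthesized wavs to --output_dir and reporting the real-time
+# factor (RTF) per utterance and overall.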
+import argparse +from pathlib import Path + +import jsonlines +import numpy as np +import onnxruntime as ort +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import get_test_dataset +from paddlespeech.t2s.utils import str2bool + + +def get_sess(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! + if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + sess = ort.InferenceSession( + model_dir, providers=providers, sess_options=sess_options) + return sess + + +def ort_predict(args): + # construct dataset for evaluation + with jsonlines.open(args.test_metadata, 'r') as reader: + test_metadata = list(reader) + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + fs = 24000 if am_dataset != 'ljspeech' else 22050 + + # am + am_sess = get_sess(args, filed='am') + + # vocoder + voc_sess = get_sess(args, filed='voc') + + # am warmup + for T in [27, 38, 54]: + data = np.random.randint(1, 266, size=(T, )) + am_sess.run(None, {"text": data}) + + # voc warmup + for T in [227, 308, 544]: + data = np.random.rand(T, 80).astype("float32") + voc_sess.run(None, {"logmel": data}) + print("warm up done!") + + N = 0 + T = 0 + for example in test_dataset: + utt_id = example['utt_id'] + phone_ids = example["text"] + with timer() as t: + mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) + mel = mel[0] + wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) + + N += len(wav[0]) + T += t.elapse + speed = len(wav[0]) / t.elapse + rtf = fs / speed + sf.write( + str(output_dir / (utt_id + ".wav")), + np.array(wav)[0], + samplerate=fs) + print( + f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
+ ) + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Infernce with onnxruntime.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'fastspeech2_csmsc', + ], + help='Choose acoustic model type of tts task.') + + # voc + parser.add_argument( + '--voc', + type=str, + default='hifigan_csmsc', + choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir") + + # inference + parser.add_argument( + "--use_trt", + type=str2bool, + default=False, + help="Whether to use inference engin TensorRT.", ) + + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + parser.add_argument('--cpu_threads', type=int, default=1) + + args, _ = parser.parse_known_args() + return args + + +def main(): + args = parse_args() + + ort_predict(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py new file mode 100644 index 000000000..8aa04cbc5 --- /dev/null +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import numpy as np +import onnxruntime as ort +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.utils import str2bool + + +def get_sess(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! 
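        # (Descriptive note: 'TensorrtExecutionProvider' is only usable with an
        #  onnxruntime-gpu build compiled against TensorRT; a plain CUDA build offers
        #  only 'CUDAExecutionProvider' and 'CPUExecutionProvider'.)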
+ if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + sess = ort.InferenceSession( + model_dir, providers=providers, sess_options=sess_options) + return sess + + +def ort_predict(args): + + # frontend + frontend = get_frontend(args) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + sentences = get_sentences(args) + + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + fs = 24000 if am_dataset != 'ljspeech' else 22050 + + # am + am_sess = get_sess(args, filed='am') + + # vocoder + voc_sess = get_sess(args, filed='voc') + + # am warmup + for T in [27, 38, 54]: + data = np.random.randint(1, 266, size=(T, )) + am_sess.run(None, {"text": data}) + + # voc warmup + for T in [227, 308, 544]: + data = np.random.rand(T, 80).astype("float32") + voc_sess.run(None, {"logmel": data}) + print("warm up done!") + + # frontend warmup + # Loading model cost 0.5+ seconds + if args.lang == 'zh': + frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True) + else: + print("lang should in be 'zh' here!") + + N = 0 + T = 0 + merge_sentences = True + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + + phone_ids = input_ids["phone_ids"] + else: + print("lang should in be 'zh' here!") + # merge_sentences=True here, so we only use the first item of phone_ids + phone_ids = phone_ids[0].numpy() + mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) + mel = mel[0] + wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) + + N += len(wav[0]) + T += t.elapse + speed = len(wav[0]) / t.elapse + rtf = fs / speed + sf.write( + str(output_dir / (utt_id + ".wav")), + np.array(wav)[0], + samplerate=fs) + print( + f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Infernce with onnxruntime.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'fastspeech2_csmsc', + ], + help='Choose acoustic model type of tts task.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + + # voc + parser.add_argument( + '--voc', + type=str, + default='hifigan_csmsc', + choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument("--output_dir", type=str, help="output dir") + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. 
zh or en') + + # inference + parser.add_argument( + "--use_trt", + type=str2bool, + default=False, + help="Whether to use inference engin TensorRT.", ) + + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + parser.add_argument('--cpu_threads', type=int, default=1) + + args, _ = parser.parse_known_args() + return args + + +def main(): + args = parse_args() + + ort_predict(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index 3f81c4e14..e833d1394 100644 --- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py +++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py @@ -79,6 +79,9 @@ def process_sentence(config: Dict[str, Any], logmel = mel_extractor.get_log_mel_fbank(wav) # change duration according to mel_length compare_duration_and_mel_length(sentences, utt_id, logmel) + # utt_id may be popped in compare_duration_and_mel_length + if utt_id not in sentences: + return None labels = sentences[utt_id][0] # extract phone and duration phones = [] diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index f38b2d352..7b9906c10 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -90,6 +90,7 @@ def evaluate(args): output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) merge_sentences = True + get_tone_ids = False N = 0 T = 0 @@ -98,8 +99,6 @@ def evaluate(args): for utt_id, sentence in sentences: with timer() as t: - get_tone_ids = False - if args.lang == 'zh': input_ids = frontend.get_input_ids( sentence, diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py index 7f41089eb..14a0d7eae 100644 --- a/paddlespeech/t2s/exps/tacotron2/preprocess.py +++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py @@ -82,6 +82,9 @@ def process_sentence(config: Dict[str, Any], logmel = mel_extractor.get_log_mel_fbank(wav) # change duration according to mel_length compare_duration_and_mel_length(sentences, utt_id, logmel) + # utt_id may be popped in compare_duration_and_mel_length + if utt_id not in sentences: + return None phones = sentences[utt_id][0] durations = sentences[utt_id][1] num_frames = logmel.shape[0] diff --git a/paddlespeech/t2s/modules/positional_encoding.py b/paddlespeech/t2s/modules/positional_encoding.py index 7c368c3aa..715c576f5 100644 --- a/paddlespeech/t2s/modules/positional_encoding.py +++ b/paddlespeech/t2s/modules/positional_encoding.py @@ -31,8 +31,9 @@ def sinusoid_position_encoding(num_positions: int, channel = paddle.arange(0, feature_size, 2, dtype=dtype) index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype) - p = (paddle.unsqueeze(index, -1) * - omega) / (10000.0**(channel / float(feature_size))) + denominator = channel / float(feature_size) + denominator = paddle.to_tensor([10000.0], dtype='float32')**denominator + p = (paddle.unsqueeze(index, -1) * omega) / denominator encodings = paddle.zeros([num_positions, feature_size], dtype=dtype) encodings[:, 0::2] = paddle.sin(p) encodings[:, 1::2] = paddle.cos(p) diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py index ee00cb535..56af61af6 100644 --- a/paddlespeech/vector/cluster/diarization.py +++ b/paddlespeech/vector/cluster/diarization.py @@ -747,6 +747,77 @@ def merge_ssegs_same_speaker(lol): return 
new_lol +def write_ders_file(ref_rttm, DER, out_der_file): + """Write the final DERs for individual recording. + + Arguments + --------- + ref_rttm : str + Reference RTTM file. + DER : array + Array containing DER values of each recording. + out_der_file : str + File to write the DERs. + """ + + rttm = read_rttm(ref_rttm) + spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm)) + + rec_id_list = [] + count = 0 + + with open(out_der_file, "w") as f: + for row in spkr_info: + a = row.split(" ") + rec_id = a[1] + if rec_id not in rec_id_list: + r = [rec_id, str(round(DER[count], 2))] + rec_id_list.append(rec_id) + line_str = " ".join(r) + f.write("%s\n" % line_str) + count += 1 + r = ["OVERALL ", str(round(DER[count], 2))] + line_str = " ".join(r) + f.write("%s\n" % line_str) + + +def get_oracle_num_spkrs(rec_id, spkr_info): + """ + Returns actual number of speakers in a recording from the ground-truth. + This can be used when the condition is oracle number of speakers. + + Arguments + --------- + rec_id : str + Recording ID for which the number of speakers have to be obtained. + spkr_info : list + Header of the RTTM file. Starting with `SPKR-INFO`. + + Example + ------- + >>> from speechbrain.processing import diarization as diar + >>> spkr_info = ['SPKR-INFO ES2011a 0 unknown ES2011a.A ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.B ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.C ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.D ', + ... 'SPKR-INFO ES2011b 0 unknown ES2011b.A ', + ... 'SPKR-INFO ES2011b 0 unknown ES2011b.B ', + ... 'SPKR-INFO ES2011b 0 unknown ES2011b.C '] + >>> diar.get_oracle_num_spkrs('ES2011a', spkr_info) + 4 + >>> diar.get_oracle_num_spkrs('ES2011b', spkr_info) + 3 + """ + + num_spkrs = 0 + for line in spkr_info: + if rec_id in line: + # Since rec_id is prefix for each speaker + num_spkrs += 1 + + return num_spkrs + + def distribute_overlap(lol): """ Distributes the overlapped speech equally among the adjacent segments @@ -827,6 +898,29 @@ def distribute_overlap(lol): return new_lol +def read_rttm(rttm_file_path): + """ + Reads and returns RTTM in list format. + + Arguments + --------- + rttm_file_path : str + Path to the RTTM file to be read. + + Returns + ------- + rttm : list + List containing rows of RTTM file. + """ + + rttm = [] + with open(rttm_file_path, "r") as f: + for line in f: + entry = line[:-1] + rttm.append(entry) + return rttm + + def write_rttm(segs_list, out_rttm_file): """ Writes the segment list in RTTM format (A standard NIST format). 
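For reference, a minimal sketch (values invented) of the per-recording DER file that
write_ders_file above emits: one "<rec_id> <DER>" line per recording listed in the
reference RTTM's SPKR-INFO header, followed by an overall line, e.g.

    ES2011a 23.45
    ES2011b 18.02
    OVERALL  21.1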
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py index d0de6dc51..70b1521ed 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/test.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py @@ -21,10 +21,11 @@ from paddle.io import DataLoader from tqdm import tqdm from yacs.config import CfgNode -from paddleaudio.datasets import VoxCeleb from paddleaudio.metric import compute_eer from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.io.dataset import CSVDataset +from paddlespeech.vector.io.embedding_norm import InputNormalization from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.training.seeding import seed_everything @@ -32,6 +33,91 @@ from paddlespeech.vector.training.seeding import seed_everything logger = Log(__name__).getlog() +def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config, + id2embedding): + """compute the dataset embeddings + + Args: + data_loader (_type_): _description_ + model (_type_): _description_ + mean_var_norm_emb (_type_): _description_ + config (_type_): _description_ + """ + logger.info( + f'Computing embeddings on {data_loader.dataset.csv_path} dataset') + with paddle.no_grad(): + for batch_idx, batch in enumerate(tqdm(data_loader)): + + # stage 8-1: extrac the audio embedding + ids, feats, lengths = batch['ids'], batch['feats'], batch['lengths'] + embeddings = model.backbone(feats, lengths).squeeze( + -1) # (N, emb_size, 1) -> (N, emb_size) + + # Global embedding normalization. + # if we use the global embedding norm + # eer can reduece about relative 10% + if config.global_embedding_norm and mean_var_norm_emb: + lengths = paddle.ones([embeddings.shape[0]]) + embeddings = mean_var_norm_emb(embeddings, lengths) + + # Update embedding dict. 
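            # (ids and embeddings are batch-aligned, so zip pairs each utterance id with
            #  its embedding row; embeddings stay paddle.Tensors so they can be reused
            #  later for cosine scoring and score normalization)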
+ id2embedding.update(dict(zip(ids, embeddings))) + + +def compute_verification_scores(id2embedding, train_cohort, config): + labels = [] + enroll_ids = [] + test_ids = [] + logger.info(f"read the trial from {config.verification_file}") + cos_sim_func = paddle.nn.CosineSimilarity(axis=-1) + scores = [] + with open(config.verification_file, 'r') as f: + for line in f.readlines(): + label, enroll_id, test_id = line.strip().split(' ') + enroll_id = enroll_id.split('.')[0].replace('/', '-') + test_id = test_id.split('.')[0].replace('/', '-') + labels.append(int(label)) + + enroll_emb = id2embedding[enroll_id] + test_emb = id2embedding[test_id] + score = cos_sim_func(enroll_emb, test_emb).item() + + if "score_norm" in config: + # Getting norm stats for enroll impostors + enroll_rep = paddle.tile( + enroll_emb, repeat_times=[train_cohort.shape[0], 1]) + score_e_c = cos_sim_func(enroll_rep, train_cohort) + if "cohort_size" in config: + score_e_c, _ = paddle.topk( + score_e_c, k=config.cohort_size, axis=0) + mean_e_c = paddle.mean(score_e_c, axis=0) + std_e_c = paddle.std(score_e_c, axis=0) + + # Getting norm stats for test impostors + test_rep = paddle.tile( + test_emb, repeat_times=[train_cohort.shape[0], 1]) + score_t_c = cos_sim_func(test_rep, train_cohort) + if "cohort_size" in config: + score_t_c, _ = paddle.topk( + score_t_c, k=config.cohort_size, axis=0) + mean_t_c = paddle.mean(score_t_c, axis=0) + std_t_c = paddle.std(score_t_c, axis=0) + + if config.score_norm == "s-norm": + score_e = (score - mean_e_c) / std_e_c + score_t = (score - mean_t_c) / std_t_c + + score = 0.5 * (score_e + score_t) + elif config.score_norm == "z-norm": + score = (score - mean_e_c) / std_e_c + elif config.score_norm == "t-norm": + score = (score - mean_t_c) / std_t_c + + scores.append(score) + + return scores, labels + + def main(args, config): # stage0: set the training device, cpu or gpu paddle.set_device(args.device) @@ -58,9 +144,8 @@ def main(args, config): # stage4: construct the enroll and test dataloader - enroll_dataset = VoxCeleb( - subset='enroll', - target_dir=args.data_dir, + enroll_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/enroll.csv"), feat_type='melspectrogram', random_chunk=False, n_mels=config.n_mels, @@ -68,16 +153,15 @@ def main(args, config): hop_length=config.hop_size) enroll_sampler = BatchSampler( enroll_dataset, batch_size=config.batch_size, - shuffle=True) # Shuffle to make embedding normalization more robust. - enrol_loader = DataLoader(enroll_dataset, + shuffle=False) # Shuffle to make embedding normalization more robust. 
+ enroll_loader = DataLoader(enroll_dataset, batch_sampler=enroll_sampler, collate_fn=lambda x: batch_feature_normalize( - x, mean_norm=True, std_norm=False), + x, mean_norm=True, std_norm=False), num_workers=config.num_workers, return_list=True,) - test_dataset = VoxCeleb( - subset='test', - target_dir=args.data_dir, + test_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/test.csv"), feat_type='melspectrogram', random_chunk=False, n_mels=config.n_mels, @@ -85,7 +169,7 @@ def main(args, config): hop_length=config.hop_size) test_sampler = BatchSampler( - test_dataset, batch_size=config.batch_size, shuffle=True) + test_dataset, batch_size=config.batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_sampler=test_sampler, collate_fn=lambda x: batch_feature_normalize( @@ -97,75 +181,65 @@ def main(args, config): # stage6: global embedding norm to imporve the performance logger.info(f"global embedding norm: {config.global_embedding_norm}") - if config.global_embedding_norm: - global_embedding_mean = None - global_embedding_std = None - mean_norm_flag = config.embedding_mean_norm - std_norm_flag = config.embedding_std_norm - batch_count = 0 # stage7: Compute embeddings of audios in enrol and test dataset from model. + + if config.global_embedding_norm: + mean_var_norm_emb = InputNormalization( + norm_type="global", + mean_norm=config.embedding_mean_norm, + std_norm=config.embedding_std_norm) + + if "score_norm" in config: + logger.info(f"we will do score norm: {config.score_norm}") + train_dataset = CSVDataset( + os.path.join(args.data_dir, "vox/csv/train.csv"), + feat_type='melspectrogram', + n_train_snts=config.n_train_snts, + random_chunk=False, + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + train_sampler = BatchSampler( + train_dataset, batch_size=config.batch_size, shuffle=False) + train_loader = DataLoader(train_dataset, + batch_sampler=train_sampler, + collate_fn=lambda x: batch_feature_normalize( + x, mean_norm=True, std_norm=False), + num_workers=config.num_workers, + return_list=True,) + id2embedding = {} # Run multi times to make embedding normalization more stable. - for i in range(2): - for dl in [enrol_loader, test_loader]: - logger.info( - f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset' - ) - with paddle.no_grad(): - for batch_idx, batch in enumerate(tqdm(dl)): - - # stage 8-1: extrac the audio embedding - ids, feats, lengths = batch['ids'], batch['feats'], batch[ - 'lengths'] - embeddings = model.backbone(feats, lengths).squeeze( - -1).numpy() # (N, emb_size, 1) -> (N, emb_size) - - # Global embedding normalization. - # if we use the global embedding norm - # eer can reduece about relative 10% - if config.global_embedding_norm: - batch_count += 1 - current_mean = embeddings.mean( - axis=0) if mean_norm_flag else 0 - current_std = embeddings.std( - axis=0) if std_norm_flag else 1 - # Update global mean and std. - if global_embedding_mean is None and global_embedding_std is None: - global_embedding_mean, global_embedding_std = current_mean, current_std - else: - weight = 1 / batch_count # Weight decay by batches. - global_embedding_mean = ( - 1 - weight - ) * global_embedding_mean + weight * current_mean - global_embedding_std = ( - 1 - weight - ) * global_embedding_std + weight * current_std - # Apply global embedding normalization. - embeddings = (embeddings - global_embedding_mean - ) / global_embedding_std - - # Update embedding dict. 
- id2embedding.update(dict(zip(ids, embeddings))) + logger.info("First loop for enroll and test dataset") + compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config, + id2embedding) + compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config, + id2embedding) + + logger.info("Second loop for enroll and test dataset") + compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config, + id2embedding) + compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config, + id2embedding) + mean_var_norm_emb.save( + os.path.join(args.load_checkpoint, "mean_var_norm_emb")) # stage 8: Compute cosine scores. - labels = [] - enroll_ids = [] - test_ids = [] - logger.info(f"read the trial from {VoxCeleb.veri_test_file}") - with open(VoxCeleb.veri_test_file, 'r') as f: - for line in f.readlines(): - label, enroll_id, test_id = line.strip().split(' ') - labels.append(int(label)) - enroll_ids.append(enroll_id.split('.')[0].replace('/', '-')) - test_ids.append(test_id.split('.')[0].replace('/', '-')) - - cos_sim_func = paddle.nn.CosineSimilarity(axis=1) - enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor( - np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')), - [enroll_ids, test_ids - ]) # (N, emb_size) - scores = cos_sim_func(enrol_embeddings, test_embeddings) + train_cohort = None + if "score_norm" in config: + train_embeddings = {} + # cohort embedding not do mean and std norm + compute_dataset_embedding(train_loader, model, None, config, + train_embeddings) + train_cohort = paddle.stack(list(train_embeddings.values())) + + # compute the scores + scores, labels = compute_verification_scores(id2embedding, train_cohort, + config) + + # compute the EER and threshold + scores = paddle.to_tensor(scores) EER, threshold = compute_eer(np.asarray(labels), scores.numpy()) logger.info( f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}' diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py index 257b97abe..b777dae89 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/train.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py @@ -23,13 +23,13 @@ from paddle.io import DistributedBatchSampler from yacs.config import CfgNode from paddleaudio.compliance.librosa import melspectrogram -from paddleaudio.datasets.voxceleb import VoxCeleb from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import waveform_augment from paddlespeech.vector.io.batch import batch_pad_right from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import waveform_collate_fn +from paddlespeech.vector.io.dataset import CSVDataset from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.modules.loss import AdditiveAngularMargin from paddlespeech.vector.modules.loss import LogSoftmaxWrapper @@ -54,8 +54,12 @@ def main(args, config): # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline # note: some cmd must do in rank==0, so wo will refactor the data prepare code - train_dataset = VoxCeleb('train', target_dir=args.data_dir) - dev_dataset = VoxCeleb('dev', target_dir=args.data_dir) + train_dataset = CSVDataset( + csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"), + label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt")) + dev_dataset = CSVDataset( + csv_path=os.path.join(args.data_dir, 
"vox/csv/dev.csv"), + label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt")) if config.augment: augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) @@ -67,7 +71,7 @@ def main(args, config): # stage4: build the speaker verification train instance with backbone model model = SpeakerIdetification( - backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers) + backbone=ecapa_tdnn, num_class=config.num_speakers) # stage5: build the optimizer, we now only construct the AdamW optimizer # 140000 is single gpu steps @@ -193,15 +197,15 @@ def main(args, config): paddle.optimizer.lr.LRScheduler): optimizer._learning_rate.step() optimizer.clear_grad() - train_run_cost += time.time() - train_start # stage 9-8: Calculate average loss per batch - avg_loss += loss.numpy()[0] + avg_loss = loss.item() # stage 9-9: Calculate metrics, which is one-best accuracy preds = paddle.argmax(logits, axis=1) num_corrects += (preds == labels).numpy().sum() num_samples += feats.shape[0] + train_run_cost += time.time() - train_start timer.count() # step plus one in timer # stage 9-10: print the log information only on 0-rank per log-freq batchs @@ -220,8 +224,9 @@ def main(args, config): train_feat_cost / config.log_interval) print_msg += ' avg_train_cost: {:.5f} sec,'.format( train_run_cost / config.log_interval) - print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format( - lr, timer.timing, timer.eta) + + print_msg += ' lr={:.4E} step/sec={:.2f} ips:{:.5f}| ETA {}'.format( + lr, timer.timing, timer.ips, timer.eta) logger.info(print_msg) avg_loss = 0 diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py index 3baace139..0aa89c6a3 100644 --- a/paddlespeech/vector/io/augment.py +++ b/paddlespeech/vector/io/augment.py @@ -14,6 +14,7 @@ # this is modified from SpeechBrain # https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py import math +import os from typing import List import numpy as np @@ -21,8 +22,8 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddleaudio.datasets.rirs_noises import OpenRIRNoise from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.dataset import CSVDataset from paddlespeech.vector.io.signal_processing import compute_amplitude from paddlespeech.vector.io.signal_processing import convolve1d from paddlespeech.vector.io.signal_processing import dB_to_amplitude @@ -509,7 +510,7 @@ class AddNoise(nn.Layer): assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' return np.pad(x, [0, w], mode=mode, **kwargs) - ids = [item['id'] for item in batch] + ids = [item['utt_id'] for item in batch] lengths = np.asarray([item['feat'].shape[0] for item in batch]) waveforms = list( map(lambda x: pad(x, max(max_length, lengths.max().item())), @@ -589,7 +590,7 @@ class AddReverb(nn.Layer): assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' return np.pad(x, [0, w], mode=mode, **kwargs) - ids = [item['id'] for item in batch] + ids = [item['utt_id'] for item in batch] lengths = np.asarray([item['feat'].shape[0] for item in batch]) waveforms = list( map(lambda x: pad(x, lengths.max().item()), @@ -839,8 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]: List[paddle.nn.Layer]: all augment process """ logger.info("start to build the augment pipeline") - noise_dataset = OpenRIRNoise('noise', target_dir=target_dir) - rir_dataset = OpenRIRNoise('rir', 
target_dir=target_dir) + noise_dataset = CSVDataset(csv_path=os.path.join(target_dir, + "rir_noise/csv/noise.csv")) + rir_dataset = CSVDataset(csv_path=os.path.join(target_dir, + "rir_noise/csv/rir.csv")) wavedrop = TimeDomainSpecAugment( sample_rate=16000, diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py index 92ca990cf..5049d1946 100644 --- a/paddlespeech/vector/io/batch.py +++ b/paddlespeech/vector/io/batch.py @@ -17,6 +17,17 @@ import paddle def waveform_collate_fn(batch): + """Wrap the waveform into a batch form + + Args: + batch (list): the waveform list from the dataloader + the item of data include several field + feat: the utterance waveform data + label: the utterance label encoding data + + Returns: + dict: the batch data to dataloader + """ waveforms = np.stack([item['feat'] for item in batch]) labels = np.stack([item['label'] for item in batch]) @@ -27,6 +38,18 @@ def feature_normalize(feats: paddle.Tensor, mean_norm: bool=True, std_norm: bool=True, convert_to_numpy: bool=False): + """Do one utterance feature normalization + + Args: + feats (paddle.Tensor): the original utterance feat, such as fbank, mfcc + mean_norm (bool, optional): mean norm flag. Defaults to True. + std_norm (bool, optional): std norm flag. Defaults to True. + convert_to_numpy (bool, optional): convert the paddle.tensor to numpy + and do feature norm with numpy. Defaults to False. + + Returns: + paddle.Tensor : the normalized feats + """ # Features normalization if needed # numpy.mean is a little with paddle.mean about 1e-6 if convert_to_numpy: @@ -60,7 +83,17 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs): def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True): - ids = [item['id'] for item in batch] + """Do batch utterance features normalization + + Args: + batch (list): the batch feature from dataloader + mean_norm (bool, optional): mean normalization flag. Defaults to True. + std_norm (bool, optional): std normalization flag. Defaults to True. + + Returns: + dict: the normalized batch features + """ + ids = [item['utt_id'] for item in batch] lengths = np.asarray([item['feat'].shape[1] for item in batch]) feats = list( map(lambda x: pad_right_2d(x, lengths.max()), diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py new file mode 100644 index 000000000..316c8ac34 --- /dev/null +++ b/paddlespeech/vector/io/dataset.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
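# Illustrative CSV contents (header wording and values are invented for this sketch; only
# the six-column order matches the parser in CSVDataset.load_data_csv below):
#
#   utt_id,duration,wav,start,stop,label
#   id10001-1zcIwhmdeo4-00001_0.0_3.0,3.0,/data/vox1/wav/id10001/1zcIwhmdeo4/00001.wav,0,48000,id10001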
+from dataclasses import dataclass +from dataclasses import fields +from paddle.io import Dataset + +from paddleaudio import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram +from paddlespeech.s2t.utils.log import Log +logger = Log(__name__).getlog() + +# the audio meta info in the vector CSVDataset +# utt_id: the utterance segment name +# duration: utterance segment time +# wav: utterance file path +# start: start point in the original wav file +# stop: stop point in the original wav file +# label: the utterance segment's label id + + +@dataclass +class meta_info: + """the audio meta info in the vector CSVDataset + + Args: + utt_id (str): the utterance segment name + duration (float): utterance segment time + wav (str): utterance file path + start (int): start point in the original wav file + stop (int): stop point in the original wav file + lab_id (str): the utterance segment's label id + """ + utt_id: str + duration: float + wav: str + start: int + stop: int + label: str + + +# csv dataset support feature type +# raw: return the pcm data sample point +# melspectrogram: fbank feature +feat_funcs = { + 'raw': None, + 'melspectrogram': melspectrogram, +} + + +class CSVDataset(Dataset): + def __init__(self, + csv_path, + label2id_path=None, + config=None, + random_chunk=True, + feat_type: str="raw", + n_train_snts: int=-1, + **kwargs): + """Implement the CSV Dataset + + Args: + csv_path (str): csv dataset file path + label2id_path (str): the utterance label to integer id map file path + config (CfgNode): yaml config + feat_type (str): dataset feature type. if it is raw, it return pcm data. + n_train_snts (int): select the n_train_snts sample from the dataset. + if n_train_snts = -1, dataset will load all the sample. + Default value is -1. + kwargs : feature type args + """ + super().__init__() + self.csv_path = csv_path + self.label2id_path = label2id_path + self.config = config + self.random_chunk = random_chunk + self.feat_type = feat_type + self.n_train_snts = n_train_snts + self.feat_config = kwargs + self.id2label = {} + self.label2id = {} + self.data = self.load_data_csv() + self.load_speaker_to_label() + + def load_data_csv(self): + """Load the csv dataset content and store them in the data property + the csv dataset's format has six fields, + that is audio_id or utt_id, audio duration, segment start point, segment stop point + and utterance label. + Note in training period, the utterance label must has a map to integer id in label2id_path + + Returns: + list: the csv data with meta_info type + """ + data = [] + + with open(self.csv_path, 'r') as rf: + for line in rf.readlines()[1:]: + audio_id, duration, wav, start, stop, spk_id = line.strip( + ).split(',') + data.append( + meta_info(audio_id, + float(duration), wav, + int(start), int(stop), spk_id)) + if self.n_train_snts > 0: + sample_num = min(self.n_train_snts, len(data)) + data = data[0:sample_num] + + return data + + def load_speaker_to_label(self): + """Load the utterance label map content. + In vector domain, we call the utterance label as speaker label. + The speaker label is real speaker label in speaker verification domain, + and in language identification is language label. 
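        Each line of the label2id file is expected to hold "<label_name> <label_id>"
        separated by a single space, e.g. "id10001 0" (example value only).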
+ """ + if not self.label2id_path: + logger.warning("No speaker id to label file") + return + + with open(self.label2id_path, 'r') as f: + for line in f.readlines(): + label_name, label_id = line.strip().split(' ') + self.label2id[label_name] = int(label_id) + self.id2label[int(label_id)] = label_name + + def convert_to_record(self, idx: int): + """convert the dataset sample to training record the CSV Dataset + + Args: + idx (int) : the request index in all the dataset + """ + sample = self.data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in fields(sample): + record[field.name] = getattr(sample, field.name) + + waveform, sr = load_audio(record['wav']) + + # random select a chunk audio samples from the audio + if self.config and self.config.random_chunk: + num_wav_samples = waveform.shape[0] + num_chunk_samples = int(self.config.chunk_duration * sr) + start = random.randint(0, num_wav_samples - num_chunk_samples - 1) + stop = start + num_chunk_samples + else: + start = record['start'] + stop = record['stop'] + + # we only return the waveform as feat + waveform = waveform[start:stop] + + # all availabel feature type is in feat_funcs + assert self.feat_type in feat_funcs.keys(), \ + f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + if self.label2id: + record.update({'label': self.label2id[record['label']]}) + + return record + + def __getitem__(self, idx): + """Return the specific index sample + + Args: + idx (int) : the request index in all the dataset + """ + return self.convert_to_record(idx) + + def __len__(self): + """Return the dataset length + + Returns: + int: the length num of the dataset + """ + return len(self.data) diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py new file mode 100644 index 000000000..5ffd2c186 --- /dev/null +++ b/paddlespeech/vector/io/dataset_from_json.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
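# Illustrative JSON layout (file name and values invented; the keys are the ones read by
# _get_data below, and the record id is recovered with key.rsplit("_", 2)[0]):
#
#   {
#     "ES2011a_0_400": {
#       "wav": {"file": "/path/ES2011a.wav", "duration": 4.0, "start": 0, "stop": 64000}
#     }
#   }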
+import json + +from dataclasses import dataclass +from dataclasses import fields +from paddle.io import Dataset + +from paddleaudio import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram +from paddleaudio.compliance.librosa import mfcc + + +@dataclass +class meta_info: + """the audio meta info in the vector JSONDataset + Args: + id (str): the segment name + duration (float): segment time + wav (str): wav file path + start (int): start point in the original wav file + stop (int): stop point in the original wav file + lab_id (str): the record id + """ + id: str + duration: float + wav: str + start: int + stop: int + record_id: str + + +# json dataset support feature type +feat_funcs = { + 'raw': None, + 'melspectrogram': melspectrogram, + 'mfcc': mfcc, +} + + +class JSONDataset(Dataset): + """ + dataset from json file. + """ + + def __init__(self, json_file: str, feat_type: str='raw', **kwargs): + """ + Ags: + json_file (:obj:`str`): Data prep JSON file. + labels (:obj:`List[int]`): Labels of audio files. + feat_type (:obj:`str`, `optional`, defaults to `raw`): + It identifies the feature type that user wants to extrace of an audio file. + """ + if feat_type not in feat_funcs.keys(): + raise RuntimeError( + f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" + ) + + self.json_file = json_file + self.feat_type = feat_type + self.feat_config = kwargs + self._data = self._get_data() + super(JSONDataset, self).__init__() + + def _get_data(self): + with open(self.json_file, "r") as f: + meta_data = json.load(f) + data = [] + for key in meta_data: + sub_seg = meta_data[key]["wav"] + wav = sub_seg["file"] + duration = sub_seg["duration"] + start = sub_seg["start"] + stop = sub_seg["stop"] + rec_id = str(key).rsplit("_", 2)[0] + data.append( + meta_info( + str(key), + float(duration), wav, int(start), int(stop), str(rec_id))) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple + for field in fields(sample): + record[field.name] = getattr(sample, field.name) + + waveform, sr = load_audio(record['wav']) + waveform = waveform[record['start']:record['stop']] + + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + + return record + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) diff --git a/paddlespeech/vector/io/embedding_norm.py b/paddlespeech/vector/io/embedding_norm.py new file mode 100644 index 000000000..619f37101 --- /dev/null +++ b/paddlespeech/vector/io/embedding_norm.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
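# How the global statistics below are updated (worked example with invented numbers):
# for batch k (1-indexed) the update weight is 1/k, so
#   glob_mean_k = (1 - 1/k) * glob_mean_{k-1} + (1/k) * batch_mean_k
# which keeps glob_mean equal to the running average of the per-batch means, e.g.
# batch means 2.0, 4.0, 6.0 give glob_mean = 2.0 -> 3.0 -> 4.0.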
+from typing import Dict + +import paddle + + +class InputNormalization: + spk_dict_mean: Dict[int, paddle.Tensor] + spk_dict_std: Dict[int, paddle.Tensor] + spk_dict_count: Dict[int, int] + + def __init__( + self, + mean_norm=True, + std_norm=True, + norm_type="global", ): + """Do feature or embedding mean and std norm + + Args: + mean_norm (bool, optional): mean norm flag. Defaults to True. + std_norm (bool, optional): std norm flag. Defaults to True. + norm_type (str, optional): norm type. Defaults to "global". + """ + super().__init__() + self.training = True + self.mean_norm = mean_norm + self.std_norm = std_norm + self.norm_type = norm_type + self.glob_mean = paddle.to_tensor([0], dtype="float32") + self.glob_std = paddle.to_tensor([0], dtype="float32") + self.spk_dict_mean = {} + self.spk_dict_std = {} + self.spk_dict_count = {} + self.weight = 1.0 + self.count = 0 + self.eps = 1e-10 + + def __call__(self, + x, + lengths, + spk_ids=paddle.to_tensor([], dtype="float32")): + """Returns the tensor with the surrounding context. + Args: + x (paddle.Tensor): A batch of tensors. + lengths (paddle.Tensor): A batch of tensors containing the relative length of each + sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid + computing stats on zero-padded steps. + spk_ids (_type_, optional): tensor containing the ids of each speaker (e.g, [0 10 6]). + It is used to perform per-speaker normalization when + norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32"). + Returns: + paddle.Tensor: The normalized feature or embedding + """ + N_batches = x.shape[0] + # print(f"x shape: {x.shape[1]}") + current_means = [] + current_stds = [] + + for snt_id in range(N_batches): + + # Avoiding padded time steps + # actual size is the actual time data length + actual_size = paddle.round(lengths[snt_id] * + x.shape[1]).astype("int32") + # computing actual time data statistics + current_mean, current_std = self._compute_current_stats( + x[snt_id, 0:actual_size, ...].unsqueeze(0)) + current_means.append(current_mean) + current_stds.append(current_std) + + if self.norm_type == "global": + current_mean = paddle.mean(paddle.stack(current_means), axis=0) + current_std = paddle.mean(paddle.stack(current_stds), axis=0) + + if self.norm_type == "global": + + if self.training: + if self.count == 0: + self.glob_mean = current_mean + self.glob_std = current_std + + else: + self.weight = 1 / (self.count + 1) + + self.glob_mean = ( + 1 - self.weight + ) * self.glob_mean + self.weight * current_mean + + self.glob_std = ( + 1 - self.weight + ) * self.glob_std + self.weight * current_std + + self.glob_mean.detach() + self.glob_std.detach() + + self.count = self.count + 1 + x = (x - self.glob_mean) / (self.glob_std) + return x + + def _compute_current_stats(self, x): + """Returns the tensor with the surrounding context. + + Args: + x (paddle.Tensor): A batch of tensors. + + Returns: + the statistics of the data + """ + # Compute current mean + if self.mean_norm: + current_mean = paddle.mean(x, axis=0).detach() + else: + current_mean = paddle.to_tensor([0.0], dtype="float32") + + # Compute current std + if self.std_norm: + current_std = paddle.std(x, axis=0).detach() + else: + current_std = paddle.to_tensor([1.0], dtype="float32") + + # Improving numerical stability of std + current_std = paddle.maximum(current_std, + self.eps * paddle.ones_like(current_std)) + + return current_mean, current_std + + def _statistics_dict(self): + """Fills the dictionary containing the normalization statistics. 
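        The saved keys are: count, glob_mean, glob_std, spk_dict_mean, spk_dict_std
        and spk_dict_count (see the assignments below).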
+ """ + state = {} + state["count"] = self.count + state["glob_mean"] = self.glob_mean + state["glob_std"] = self.glob_std + state["spk_dict_mean"] = self.spk_dict_mean + state["spk_dict_std"] = self.spk_dict_std + state["spk_dict_count"] = self.spk_dict_count + + return state + + def _load_statistics_dict(self, state): + """Loads the dictionary containing the statistics. + + Arguments + --------- + state : dict + A dictionary containing the normalization statistics. + """ + self.count = state["count"] + if isinstance(state["glob_mean"], int): + self.glob_mean = state["glob_mean"] + self.glob_std = state["glob_std"] + else: + self.glob_mean = state["glob_mean"] # .to(self.device_inp) + self.glob_std = state["glob_std"] # .to(self.device_inp) + + # Loading the spk_dict_mean in the right device + self.spk_dict_mean = {} + for spk in state["spk_dict_mean"]: + self.spk_dict_mean[spk] = state["spk_dict_mean"][spk] + + # Loading the spk_dict_std in the right device + self.spk_dict_std = {} + for spk in state["spk_dict_std"]: + self.spk_dict_std[spk] = state["spk_dict_std"][spk] + + self.spk_dict_count = state["spk_dict_count"] + + return state + + def to(self, device): + """Puts the needed tensors in the right device. + """ + self = super(InputNormalization, self).to(device) + self.glob_mean = self.glob_mean.to(device) + self.glob_std = self.glob_std.to(device) + for spk in self.spk_dict_mean: + self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device) + self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device) + return self + + def save(self, path): + """Save statistic dictionary. + + Args: + path (str): A path where to save the dictionary. + """ + stats = self._statistics_dict() + paddle.save(stats, path) + + def _load(self, path, end_of_epoch=False, device=None): + """Load statistic dictionary. + + Arguments + --------- + path : str + The path of the statistic dictionary + device : str, None + Passed to paddle.load(..., map_location=device) + """ + del end_of_epoch # Unused here. 
+ stats = paddle.load(path, map_location=device) + self._load_statistics_dict(stats) diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py index 0e7287cd3..895ff13f4 100644 --- a/paddlespeech/vector/models/ecapa_tdnn.py +++ b/paddlespeech/vector/models/ecapa_tdnn.py @@ -79,6 +79,20 @@ class Conv1d(nn.Layer): bias_attr=bias, ) def forward(self, x): + """Do conv1d forward + + Args: + x (paddle.Tensor): [N, C, L] input data, + N is the batch, + C is the data dimension, + L is the time + + Raises: + ValueError: only support the same padding type + + Returns: + paddle.Tensor: the value of conv1d + """ if self.padding == "same": x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) @@ -88,6 +102,20 @@ class Conv1d(nn.Layer): return self.conv(x) def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + """Padding the input data + + Args: + x (paddle.Tensor): [N, C, L] input data + N is the batch, + C is the data dimension, + L is the time + kernel_size (int): 1-d convolution kernel size + dilation (int): 1-d convolution dilation + stride (int): 1-d convolution stride + + Returns: + paddle.Tensor: the padded input data + """ L_in = x.shape[-1] # Detecting input shape padding = self._get_padding_elem(L_in, stride, kernel_size, dilation) # Time padding @@ -101,6 +129,17 @@ class Conv1d(nn.Layer): stride: int, kernel_size: int, dilation: int): + """Calculate the padding value in same mode + + Args: + L_in (int): the times of the input data, + stride (int): 1-d convolution stride + kernel_size (int): 1-d convolution kernel size + dilation (int): 1-d convolution stride + + Returns: + int: return the padding value in same mode + """ if stride > 1: n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) L_out = stride * (n_steps - 1) + kernel_size * dilation @@ -245,6 +284,13 @@ class SEBlock(nn.Layer): class AttentiveStatisticsPooling(nn.Layer): def __init__(self, channels, attention_channels=128, global_context=True): + """Compute the speaker verification statistics + The detail info is section 3.1 in https://arxiv.org/pdf/1709.01507.pdf + Args: + channels (int): input data channel or data dimension + attention_channels (int, optional): attention dimension. Defaults to 128. + global_context (bool, optional): If use the global context information. Defaults to True. + """ super().__init__() self.eps = 1e-12 diff --git a/paddlespeech/vector/utils/time.py b/paddlespeech/vector/utils/time.py index 8e85b0e12..9dfbbe1f7 100644 --- a/paddlespeech/vector/utils/time.py +++ b/paddlespeech/vector/utils/time.py @@ -23,6 +23,7 @@ class Timer(object): self.last_start_step = 0 self.current_step = 0 self._is_running = True + self.cur_ips = 0 def start(self): self.last_time = time.time() @@ -43,12 +44,17 @@ class Timer(object): self.last_start_step = self.current_step time_used = time.time() - self.last_time self.last_time = time.time() + self.cur_ips = run_steps / time_used return time_used / run_steps @property def is_running(self) -> bool: return self._is_running + @property + def ips(self) -> float: + return self.cur_ips + @property def eta(self) -> str: if not self.is_running: diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py new file mode 100644 index 000000000..46de7ffaa --- /dev/null +++ b/paddlespeech/vector/utils/vector_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_chunks(seg_dur, audio_id, audio_duration): + """Get all chunk segments from a utterance + + Args: + seg_dur (float): segment chunk duration, seconds + audio_id (str): utterance name, + audio_duration (float): utterance duration, seconds + + Returns: + List: all the chunk segments + """ + num_chunks = int(audio_duration / seg_dur) # all in seconds + chunk_lst = [ + audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur) + for i in range(num_chunks) + ] + return chunk_lst diff --git a/speechx/examples/aishell/local/compute-wer.py b/speechx/examples/aishell/local/compute-wer.py new file mode 100755 index 000000000..a3eefc0dc --- /dev/null +++ b/speechx/examples/aishell/local/compute-wer.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +import re, sys, unicodedata +import codecs + +remove_tag = True +spacelist= [' ', '\t', '\r', '\n'] +puncts = ['!', ',', '?', + '、', '。', '!', ',', ';', '?', + ':', '「', '」', '︰', '『', '』', '《', '》'] + +def characterize(string) : + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. 
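            # (e.g. an ASCII run or a tag such as "<unk>" glued to CJK characters is
            #  split off and kept as one token; the tag name here is only illustrative)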
+ sep = ' ' + if char == '<': sep = '>' + j = i+1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c==sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + +def stripoff_tags(x): + if not x: return '' + chars = [] + i = 0; T=len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + +class Calculator : + def __init__(self) : + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + def calculate(self, lab, rec) : + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab) : + self.space.append([]) + for row in self.space : + for element in row : + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec) : + row.append({'dist' : 0, 'error' : 'non'}) + for i in range(len(lab)) : + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)) : + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + for token in rec : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + # Computing edit distance + for i, lab_token in enumerate(lab) : + for j, rec_token in enumerate(rec) : + if i == 0 or j == 0 : + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i-1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist : + min_dist = dist + min_error = error + dist = self.space[i][j-1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist : + min_dist = dist + min_error = error + if lab_token == rec_token : + dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] + error = 'cor' + else : + dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist : + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + i = len(lab) - 1 + j = len(rec) - 1 + while True : + if self.space[i][j]['error'] == 'cor' : # correct + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub' : # substitution + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + 
j = j - 1 + elif self.space[i][j]['error'] == 'del' : # deletion + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins' : # insertion + if len(rec[j]) > 0 : + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non' : # starting point + break + else : # shouldn't reach here + print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) + return result + def overall(self) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def cluster(self, data) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in data : + if token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def keys(self) : + return list(self.data.keys()) + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + +def default_cluster(word) : + unicode_names = [ unicodedata.name(char) for char in word ] + for i in reversed(range(len(unicode_names))) : + if unicode_names[i].startswith('DIGIT') : # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')) : + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')) : + # & / ' / @ / ℃ / = / . 
/ - / _ / # / + / ; + del unicode_names[i] + else : + return 'Other' + if len(unicode_names) == 0 : + return 'Other' + if len(unicode_names) == 1 : + return unicode_names[0] + for i in range(len(unicode_names)-1) : + if unicode_names[i] != unicode_names[i+1] : + return 'Other' + return unicode_names[0] + +def usage() : + print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") + print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + +if __name__ == '__main__': + if len(sys.argv) == 1 : + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose= 1 + padding_symbol= ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose=0 + try: + verbose=int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol= ' ' + elif b == 'underline': + padding_symbol= '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig=set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array)==0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8') : + if tochar: + array = characterize(line) 
+ else: + array = line.rstrip('\n').split() + if len(array)==0: continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab : + if word not in default_words : + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters : + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name] : + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('WER: %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])) : + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length-len_lab) + space['rec'].append(length-len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('lab:', end = ' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['lab'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('rec:', end = ' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['rec'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print('===========================================================================') + print() + + result = calculator.overall() + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('Overall -> %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters : + result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if len(cluster_file) > 0 : # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8') : + for token in line.decode('utf-8').rstrip('\n').split() : + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + 
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else : + cluster.append(token) + print() + print('===========================================================================') diff --git a/speechx/examples/aishell/local/split_data.sh b/speechx/examples/aishell/local/split_data.sh new file mode 100755 index 000000000..df454d6cf --- /dev/null +++ b/speechx/examples/aishell/local/split_data.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +data=$1 +feat_scp=$2 +split_feat_name=$3 +numsplit=$4 + + +if ! [ "$numsplit" -gt 0 ]; then + echo "Invalid num-split argument"; + exit 1; +fi + +directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done) +feat_split_scp=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_feat_name}; done) +echo $feat_split_scp +# if this mkdir fails due to argument-list being too long, iterate. +if ! mkdir -p $directories >&/dev/null; then + for n in `seq $numsplit`; do + mkdir -p $data/split${numsplit}/$n + done +fi + +utils/split_scp.pl $feat_scp $feat_split_scp diff --git a/speechx/examples/aishell/path.sh b/speechx/examples/aishell/path.sh new file mode 100644 index 000000000..a0e7c9aed --- /dev/null +++ b/speechx/examples/aishell/path.sh @@ -0,0 +1,14 @@ +# This contains the locations of binarys build required for running the examples. + +SPEECHX_ROOT=$PWD/../.. +SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples + +SPEECHX_TOOLS=$SPEECHX_ROOT/tools +TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin + +[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } + +export LC_AL=C + +SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/aishell/run.sh b/speechx/examples/aishell/run.sh new file mode 100755 index 000000000..8a16a8650 --- /dev/null +++ b/speechx/examples/aishell/run.sh @@ -0,0 +1,113 @@ +#!/bin/bash +set +x +set -e + +. path.sh + +# 1. compile +if [ ! -d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + + +# 2. download model +if [ ! -d ../paddle_asr_model ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz + tar xzfv paddle_asr_model.tar.gz + mv ./paddle_asr_model ../ + # produce wav scp + echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp +fi + +mkdir -p data +data=$PWD/data +aishell_wav_scp=aishell_test.scp +if [ ! -d $data/test ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip -d $data aishell_test.zip + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp +fi + +model_dir=$PWD/aishell_ds2_online_model +if [ ! -d $model_dir ]; then + mkdir -p $model_dir + wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz + tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir +fi + +# 3. 
make feature +aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints +lm_model_dir=../paddle_asr_model +label_file=./aishell_result +wer=./aishell_wer + +nj=40 +export GLOG_logtostderr=1 + +#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj + +data=$PWD/data +# 3. gen linear feat +cmvn=$PWD/cmvn.ark +cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn + +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \ +linear_spectrogram_without_db_norm_main \ + --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ + --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ + --cmvn_file=$cmvn \ + --streaming_chunk=0.36 + +text=$data/test/text + +# 4. recognizer +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \ + offline_decoder_sliding_chunk_main \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$aishell_online_model/avg_1.jit.pdmodel \ + --param_path=$aishell_online_model/avg_1.jit.pdiparams \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --dict_file=$lm_model_dir/vocab.txt \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result + +cat $data/split${nj}/*/result > ${label_file} +local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer} + +# 4. decode with lm +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \ + offline_decoder_sliding_chunk_main \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$aishell_online_model/avg_1.jit.pdmodel \ + --param_path=$aishell_online_model/avg_1.jit.pdiparams \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --dict_file=$lm_model_dir/vocab.txt \ + --lm_path=$lm_model_dir/avg_1.jit.klm \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm + +cat $data/split${nj}/*/result_lm > ${label_file}_lm +local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm + +graph_dir=./aishell_graph +if [ ! -d $ ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip + unzip -d aishell_graph.zip +fi + +# 5. 
test TLG decoder +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \ + offline_wfst_decoder_main \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$aishell_online_model/avg_1.jit.pdmodel \ + --param_path=$aishell_online_model/avg_1.jit.pdiparams \ + --word_symbol_table=$graph_dir/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --graph_path=$graph_dir/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg + +cat $data/split${nj}/*/result_tlg > ${label_file}_tlg +local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg \ No newline at end of file diff --git a/speechx/examples/aishell/utils b/speechx/examples/aishell/utils new file mode 120000 index 000000000..973afe674 --- /dev/null +++ b/speechx/examples/aishell/utils @@ -0,0 +1 @@ +../../../utils \ No newline at end of file diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index ded423e94..d446a6715 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_ target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc) +target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) + add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc) target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc index 7f6c572ca..40092de31 100644 --- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -22,30 +22,37 @@ #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" -DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); +DEFINE_string(lm_path, "", "language model"); DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=5) downsampling module."); DEFINE_int32(downsampling_rate, 4, "two CNN(kernel=5) module downsampling rate."); +DEFINE_string(model_output_names, + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", + "model output names"); +DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; - // 
test ds2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_respecifier); + FLAGS_feature_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); std::string model_graph = FLAGS_model_path; std::string model_params = FLAGS_param_path; std::string dict_file = FLAGS_dict_file; @@ -55,7 +62,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "lm path: " << lm_path; - int32 num_done = 0, num_err = 0; ppspeech::CTCBeamSearchOptions opts; @@ -66,7 +72,8 @@ int main(int argc, char* argv[]) { ppspeech::ModelOptions model_opts; model_opts.model_path = model_graph; model_opts.params_path = model_params; - model_opts.cache_shape = "5-1-1024,5-1-1024"; + model_opts.cache_shape = FLAGS_model_cache_names; + model_opts.output_names = FLAGS_model_output_names; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); std::shared_ptr raw_data(new ppspeech::DataCache()); @@ -129,9 +136,16 @@ int main(int argc, char* argv[]) { } std::string result; result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; decodable->Reset(); decoder.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); ++num_done; } diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc new file mode 100644 index 000000000..06460a45e --- /dev/null +++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
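
Both decoder mains in this patch (the sliding-chunk CTC decoder above and the TLG decoder in the file below) feed features to the acoustic model in sliding chunks: the chunk size is the receptive field (7 frames by default), the stride is the downsampling rate (4 frames), and the feature matrix is padded so the last chunk lines up with the stride. A minimal Python sketch of that padding and chunk-count arithmetic follows; the function and variable names are illustrative and not part of the patch.

    def plan_chunks(num_rows, chunk_size=7, chunk_stride=4):
        """Return (padding, num_chunks) so chunks of `chunk_size` frames,
        advanced by `chunk_stride` frames, exactly cover the padded matrix."""
        padding = 0
        if (num_rows - chunk_size) % chunk_stride != 0:
            padding = chunk_stride - (num_rows - chunk_size) % chunk_stride
        num_chunks = (num_rows + padding - chunk_size) // chunk_stride + 1
        return padding, num_chunks

    # With the defaults above: 83 feature frames -> (0, 20); 85 frames -> (2, 21).
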
+ +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/audio/data_cache.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "TLG", "decoder graph"); +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_int32(max_active, 7500, "decoder graph"); +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=5) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=5) module downsampling rate."); +DEFINE_string(model_output_names, + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", + "model output names"); +DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +// test TLG decoder by feeding speech feature. +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string word_symbol_table = FLAGS_word_symbol_table; + std::string graph_path = FLAGS_graph_path; + LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model param: " << model_params; + LOG(INFO) << "word symbol path: " << word_symbol_table; + LOG(INFO) << "graph path: " << graph_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::TLGDecoderOptions opts; + opts.word_symbol_table = word_symbol_table; + opts.fst_path = graph_path; + opts.opts.max_active = FLAGS_max_active; + opts.opts.beam = 15.0; + opts.opts.lattice_beam = 7.5; + ppspeech::TLGDecoder decoder(opts); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; + model_opts.cache_shape = FLAGS_model_cache_names; + model_opts.output_names = FLAGS_model_output_names; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); + + int32 chunk_size = FLAGS_receptive_field_length; + int32 chunk_stride = FLAGS_downsampling_rate; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; + decoder.InitDecoder(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + kaldi::Matrix feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + LOG(INFO) << "process utt: " << utt; + LOG(INFO) << "rows: " << feature.NumRows(); + LOG(INFO) << "cols: " << feature.NumCols(); + + int32 row_idx = 0; + int32 padding_len = 0; + int32 ori_feature_len = feature.NumRows(); + if 
((feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = + chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + feature.Resize(feature.NumRows() + padding_len, + feature.NumCols(), + kaldi::kCopyData); + } + int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + int32 feature_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (feature_chunk_size < receptive_field_length) break; + + int32 start = chunk_idx * chunk_stride; + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, start); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + ++start; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } + std::string result; + result = decoder.GetFinalBestPath(); + decodable->Reset(); + decoder.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt index b8f516afb..d6fdb9bc6 100644 --- a/speechx/examples/feat/CMakeLists.txt +++ b/speechx/examples/feat/CMakeLists.txt @@ -7,4 +7,12 @@ target_link_libraries(mfcc-test kaldi-mfcc) add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc) target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) \ No newline at end of file +target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) + +add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc) +target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog) + +add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc) +target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog) diff --git a/speechx/examples/feat/cmvn_json2binary_main.cc b/speechx/examples/feat/cmvn_json2binary_main.cc new file mode 100644 index 000000000..e77f983aa --- /dev/null +++ b/speechx/examples/feat/cmvn_json2binary_main.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/flags.h" +#include "base/log.h" +#include "kaldi/matrix/kaldi-matrix.h" +#include "kaldi/util/kaldi-io.h" +#include "utils/file_utils.h" +#include "utils/simdjson.h" + +DEFINE_string(json_file, "", "cmvn json file"); +DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); +DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); + +using namespace simdjson; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + ondemand::parser parser; + padded_string json = padded_string::load(FLAGS_json_file); + ondemand::document val = parser.iterate(json); + ondemand::object doc = val; + kaldi::int32 frame_num = uint64_t(doc["frame_num"]); + auto mean_stat = doc["mean_stat"]; + std::vector mean_stat_vec; + for (double x : mean_stat) { + mean_stat_vec.push_back(x); + } + auto var_stat = doc["var_stat"]; + std::vector var_stat_vec; + for (double x : var_stat) { + var_stat_vec.push_back(x); + } + + size_t mean_size = mean_stat_vec.size(); + kaldi::Matrix cmvn_stats(2, mean_size + 1); + for (size_t idx = 0; idx < mean_size; ++idx) { + cmvn_stats(0, idx) = mean_stat_vec[idx]; + cmvn_stats(1, idx) = var_stat_vec[idx]; + } + cmvn_stats(0, mean_size) = frame_num; + kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); + LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path; + return 0; +} \ No newline at end of file diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc index 2d75bb5df..2e70386d6 100644 --- a/speechx/examples/feat/linear_spectrogram_main.cc +++ b/speechx/examples/feat/linear_spectrogram_main.cc @@ -30,6 +30,7 @@ DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); std::vector mean_{ @@ -181,6 +182,7 @@ int main(int argc, char* argv[]) { ppspeech::LinearSpectrogramOptions opt; opt.frame_opts.frame_length_ms = 20; opt.frame_opts.frame_shift_ms = 10; + opt.streaming_chunk = FLAGS_streaming_chunk; opt.frame_opts.dither = 0.0; opt.frame_opts.remove_dc_offset = false; opt.frame_opts.window_type = "hanning"; @@ -198,7 +200,7 @@ int main(int argc, char* argv[]) { LOG(INFO) << "feat dim: " << feature_cache.Dim(); int sample_rate = 16000; - float streaming_chunk = 0.36; + float streaming_chunk = FLAGS_streaming_chunk; int chunk_sample_size = streaming_chunk * sample_rate; LOG(INFO) << "sr: " << sample_rate; LOG(INFO) << "chunk size (s): " << streaming_chunk; @@ -256,6 +258,7 @@ int main(int argc, char* argv[]) { } } feat_writer.Write(utt, features); + feature_cache.Reset(); if (num_done % 50 == 0 && num_done != 0) KALDI_VLOG(2) << "Processed " << num_done << " utterances"; diff --git a/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc b/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc new file mode 100644 index 000000000..5b875a3ee --- /dev/null +++ 
b/speechx/examples/feat/linear_spectrogram_without_db_norm_main.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +#include "frontend/audio/audio_cache.h" +#include "frontend/audio/data_cache.h" +#include "frontend/audio/feature_cache.h" +#include "frontend/audio/frontend_itf.h" +#include "frontend/audio/linear_spectrogram.h" +#include "frontend/audio/normalizer.h" + +DEFINE_string(wav_rspecifier, "", "test wav scp path"); +DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); +DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); + + int32 num_done = 0, num_err = 0; + + // feature pipeline: wave cache --> hanning + // window -->linear_spectrogram --> global cmvn -> feat cache + + std::unique_ptr data_source( + new ppspeech::AudioCache(3600 * 1600, true)); + + ppspeech::LinearSpectrogramOptions opt; + opt.frame_opts.frame_length_ms = 20; + opt.frame_opts.frame_shift_ms = 10; + opt.streaming_chunk = FLAGS_streaming_chunk; + opt.frame_opts.dither = 0.0; + opt.frame_opts.remove_dc_offset = false; + opt.frame_opts.window_type = "hanning"; + opt.frame_opts.preemph_coeff = 0.0; + LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms; + LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms; + + std::unique_ptr linear_spectrogram( + new ppspeech::LinearSpectrogram(opt, std::move(data_source))); + + std::unique_ptr cmvn( + new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram))); + + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); + LOG(INFO) << "feat dim: " << feature_cache.Dim(); + + int sample_rate = 16000; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "process utt: " << utt; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + std::vector> feats; + int feature_rows = 0; + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - 
sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + + kaldi::Vector features; + feature_cache.Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + feature_cache.SetFinished(); + } + feature_cache.Read(&features); + if (features.Dim() == 0) break; + + feats.push_back(features); + sample_offset += cur_chunk_size; + feature_rows += features.Dim() / feature_cache.Dim(); + } + + int cur_idx = 0; + kaldi::Matrix features(feature_rows, + feature_cache.Dim()); + for (auto feat : feats) { + int num_rows = feat.Dim() / feature_cache.Dim(); + for (int row_idx = 0; row_idx < num_rows; ++row_idx) { + for (size_t col_idx = 0; col_idx < feature_cache.Dim(); + ++col_idx) { + features(cur_idx, col_idx) = + feat(row_idx * feature_cache.Dim() + col_idx); + } + ++cur_idx; + } + } + feat_writer.Write(utt, features); + feature_cache.Reset(); + + if (num_done % 50 == 0 && num_done != 0) + KALDI_VLOG(2) << "Processed " << num_done << " utterances"; + num_done++; + } + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 7cd281b66..ee0863fd5 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -6,5 +6,6 @@ add_library(decoder STATIC ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp + ctc_tlg_decoder.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst) \ No newline at end of file +target_link_libraries(decoder PUBLIC kenlm utils fst) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 5d7a4f77a..b4caa8e7b 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( vector> likelihood; vector frame_prob; bool flag = - decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); if (flag == false) break; likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 1387eee79..9d0a5d142 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -15,7 +15,7 @@ #include "base/common.h" #include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/scorer.h" -#include "nnet/decodable-itf.h" +#include "kaldi/decoder/decodable-itf.h" #include "util/parse-options.h" #pragma once diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc new file mode 100644 index 000000000..5365e7090 --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/ctc_tlg_decoder.h" +namespace ppspeech { + +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { + fst_.reset(fst::Fst::Read(opts.fst_path)); + CHECK(fst_ != nullptr); + word_symbol_table_.reset( + fst::SymbolTable::ReadText(opts.word_symbol_table)); + decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); + decoder_->InitDecoding(); + frame_decoded_size_ = 0; +} + +void TLGDecoder::InitDecoder() { + decoder_->InitDecoding(); + frame_decoded_size_ = 0; +} + +void TLGDecoder::AdvanceDecode( + const std::shared_ptr& decodable) { + while (!decodable->IsLastFrame(frame_decoded_size_)) { + LOG(INFO) << "num frame decode: " << frame_decoded_size_; + AdvanceDecoding(decodable.get()); + } +} + +void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { + decoder_->AdvanceDecoding(decodable, 1); + frame_decoded_size_++; +} + +void TLGDecoder::Reset() { + InitDecoder(); + return; +} + +std::string TLGDecoder::GetFinalBestPath() { + decoder_->FinalizeDecoding(); + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, true); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} +} \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h new file mode 100644 index 000000000..361c44af5 --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/basic_types.h" +#include "kaldi/decoder/decodable-itf.h" +#include "kaldi/decoder/lattice-faster-online-decoder.h" +#include "util/parse-options.h" + +namespace ppspeech { + +struct TLGDecoderOptions { + kaldi::LatticeFasterDecoderConfig opts; + // todo remove later, add into decode resource + std::string word_symbol_table; + std::string fst_path; + + TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} +}; + +class TLGDecoder { + public: + explicit TLGDecoder(TLGDecoderOptions opts); + void InitDecoder(); + void Decode(); + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::string GetFinalBestPath(); + int NumFrameDecoded(); + int DecodeLikelihoods(const std::vector>& probs, + std::vector& nbest_words); + void AdvanceDecode( + const std::shared_ptr& decodable); + void Reset(); + + private: + void AdvanceDecoding(kaldi::DecodableInterface* decodable); + + std::shared_ptr decoder_; + std::shared_ptr> fst_; + std::shared_ptr word_symbol_table_; + // the frame size which have decoded starts from 0. 
+ int32 frame_decoded_size_; +}; + + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc index c3233e595..50aca4fb0 100644 --- a/speechx/speechx/frontend/audio/audio_cache.cc +++ b/speechx/speechx/frontend/audio/audio_cache.cc @@ -21,15 +21,20 @@ using kaldi::BaseFloat; using kaldi::VectorBase; using kaldi::Vector; -AudioCache::AudioCache(int buffer_size) +AudioCache::AudioCache(int buffer_size, bool convert2PCM32) : finished_(false), capacity_(buffer_size), size_(0), offset_(0), - timeout_(1) { + timeout_(1), + convert2PCM32_(convert2PCM32) { ring_buffer_.resize(capacity_); } +BaseFloat AudioCache::Convert2PCM32(BaseFloat val) { + return val * (1. / std::pow(2.0, 15)); +} + void AudioCache::Accept(const VectorBase& waves) { std::unique_lock lock(mutex_); while (size_ + waves.Dim() > ring_buffer_.size()) { @@ -38,6 +43,8 @@ void AudioCache::Accept(const VectorBase& waves) { for (size_t idx = 0; idx < waves.Dim(); ++idx) { int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); ring_buffer_[buffer_idx] = waves(idx); + if (convert2PCM32_) + ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); } size_ += waves.Dim(); } diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h index 17e1a8389..adef12399 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/speechx/speechx/frontend/audio/audio_cache.h @@ -23,7 +23,8 @@ namespace ppspeech { // waves cache class AudioCache : public FrontendInterface { public: - explicit AudioCache(int buffer_size = kint16max); + explicit AudioCache(int buffer_size = 1000 * kint16max, + bool convert2PCM32 = false); virtual void Accept(const kaldi::VectorBase& waves); @@ -46,14 +47,17 @@ class AudioCache : public FrontendInterface { } private: + kaldi::BaseFloat Convert2PCM32(kaldi::BaseFloat val); + std::vector ring_buffer_; size_t offset_; // offset in ring_buffer_ size_t size_; // samples in ring_buffer_ now size_t capacity_; // capacity of ring_buffer_ bool finished_; // reach audio end - mutable std::mutex mutex_; + std::mutex mutex_; std::condition_variable ready_feed_condition_; kaldi::int32 timeout_; // millisecond + bool convert2PCM32_; DISALLOW_COPY_AND_ASSIGN(AudioCache); }; diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/frontend/audio/cmvn.cc index 4c1ffd6a1..c7e446c92 100644 --- a/speechx/speechx/frontend/audio/cmvn.cc +++ b/speechx/speechx/frontend/audio/cmvn.cc @@ -120,4 +120,4 @@ void CMVN::ApplyCMVN(kaldi::MatrixBase* feats) { ApplyCmvn(stats_, var_norm_, feats); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h index 896c494dd..6b20b8b94 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.h +++ b/speechx/speechx/frontend/audio/linear_spectrogram.h @@ -46,7 +46,10 @@ class LinearSpectrogram : public FrontendInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } - virtual void Reset() { base_extractor_->Reset(); } + virtual void Reset() { + base_extractor_->Reset(); + reminded_wav_.Resize(0); + } private: bool Compute(const kaldi::Vector& waves, diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index 414a6fa0c..6f7398cd1 100644 --- 
a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -4,3 +4,6 @@ add_subdirectory(base) add_subdirectory(util) add_subdirectory(feat) add_subdirectory(matrix) +add_subdirectory(lat) +add_subdirectory(fstext) +add_subdirectory(decoder) diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/speechx/speechx/kaldi/decoder/CMakeLists.txt new file mode 100644 index 000000000..f1ee6eabb --- /dev/null +++ b/speechx/speechx/kaldi/decoder/CMakeLists.txt @@ -0,0 +1,6 @@ + +add_library(kaldi-decoder +lattice-faster-decoder.cc +lattice-faster-online-decoder.cc +) +target_link_libraries(kaldi-decoder PUBLIC kaldi-lat) diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h similarity index 98% rename from speechx/speechx/nnet/decodable-itf.h rename to speechx/speechx/kaldi/decoder/decodable-itf.h index 8e9a5a72a..b8ce9143e 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -121,7 +121,7 @@ class DecodableInterface { /// decoding-from-matrix setting where we want to allow the last delta or /// LDA /// features to be flushed out for compatibility with the baseline setup. - virtual bool IsLastFrame(int32 frame) const = 0; + virtual bool IsLastFrame(int32 frame) = 0; /// The call NumFramesReady() will return the number of frames currently /// available @@ -143,7 +143,7 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; - virtual bool FrameLogLikelihood( + virtual bool FrameLikelihood( int32 frame, std::vector* likelihood) = 0; diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc index 42d1d2af4..ae6b71600 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc @@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl, decoder::StdToken> template class LatticeFasterDecoderTpl, decoder::StdToken >; template class LatticeFasterDecoderTpl, decoder::StdToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; template class LatticeFasterDecoderTpl , decoder::BackpointerToken>; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; } // end namespace kaldi. 
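
The decodable interface that this patch relocates to kaldi/decoder/decodable-itf.h is what both decoders poll: IsLastFrame() says whether more frames can still arrive, and FrameLikelihood() hands back one frame of per-token posteriors, returning false when the feature pipeline has nothing ready yet. The sketch below is a rough Python illustration of that consumption loop only; it is not the C++ signatures, and the actual decoders each use only part of this pattern.

    def advance_decode(decodable, consume_frame):
        """Drain per-frame posteriors from a decodable-like object."""
        frame = 0
        while not decodable.is_last_frame(frame):
            ready, probs = decodable.frame_likelihood(frame)
            if not ready:          # features for this frame not produced yet
                break
            consume_frame(probs)   # e.g. one beam-search or lattice step
            frame += 1
        return frame
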
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h index 2016ad571..d142a8c7d 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h @@ -23,11 +23,10 @@ #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ -#include "decoder/grammar-fst.h" #include "fst/fstlib.h" #include "fst/memory.h" #include "fstext/fstext-lib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" #include "util/hash-list.h" diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc index ebdace7e8..b5261503c 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc @@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; -template class LatticeFasterOnlineDecoderTpl; -template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; } // end namespace kaldi. diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h index 8b10996fd..f57368a44 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h @@ -30,7 +30,7 @@ #include "util/stl-utils.h" #include "util/hash-list.h" #include "fst/fstlib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt new file mode 100644 index 000000000..af91fd985 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_library(kaldi-fstext +kaldi-fst-io.cc +) +target_link_libraries(kaldi-fstext PUBLIC kaldi-util) diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h new file mode 100644 index 000000000..0bfbc8f41 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h @@ -0,0 +1,1357 @@ +// fstext/determinize-lattice-inl.h + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
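
The vendored determinize-lattice-inl.h below starts with a LatticeStringRepository, which interns label sequences as parent-pointer entries so that finding the successor of (string, next-label) costs constant time. A rough Python sketch of that interning idea follows, with illustrative names; the C++ class additionally owns its entries and provides Concatenate, CommonPrefix, RemovePrefix and Rebuild.

    class StringRepository:
        """Intern label sequences as (parent, label) pairs; id 0 is the empty string."""
        def __init__(self):
            self._ids = {}            # (parent_id, label) -> entry id
            self._entries = [None]    # entry id -> (parent_id, label)

        def successor(self, parent_id, label):
            key = (parent_id, label)
            if key not in self._ids:
                self._entries.append(key)
                self._ids[key] = len(self._entries) - 1
            return self._ids[key]

        def to_list(self, entry_id):
            labels = []
            while entry_id != 0:
                parent_id, label = self._entries[entry_id]
                labels.append(label)
                entry_id = parent_id
            return labels[::-1]

    # repo = StringRepository(); s = repo.successor(repo.successor(0, 7), 9)
    # repo.to_list(s) == [7, 9]
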
+ +#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +// Do not include this file directly. It is included by determinize-lattice.h + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// This class maps back and forth from/to integer id's to sequences of strings. +// used in determinization algorithm. It is constructed in such a way that +// finding the string-id of the successor of (string, next-label) has constant +// time. + +// Note: class IntType, typically int32, is the type of the element in the +// string (typically a template argument of the CompactLatticeWeightTpl). + +template +class LatticeStringRepository { + public: + struct Entry { + const Entry *parent; // NULL for empty string. + IntType i; + inline bool operator==(const Entry &other) const { + return (parent == other.parent && i == other.i); + } + Entry() {} + Entry(const Entry &e) : parent(e.parent), i(e.i) {} + }; + // Note: all Entry* pointers returned in function calls are + // owned by the repository itself, not by the caller! + + // Interface guarantees empty string is NULL. + inline const Entry *EmptyString() { return NULL; } + + // Returns string of "parent" with i appended. Pointer + // owned by repository + const Entry *Successor(const Entry *parent, IntType i) { + new_entry_->parent = parent; + new_entry_->i = i; + + std::pair pr = set_.insert(new_entry_); + if (pr.second) { // Was successfully inserted (was not there). We need to + // replace the element we inserted, which resides on the + // stack, with one from the heap. + const Entry *ans = new_entry_; + new_entry_ = new Entry(); + return ans; + } else { // Was not inserted because an equivalent Entry already + // existed. + return *pr.first; + } + } + + const Entry *Concatenate(const Entry *a, const Entry *b) { + if (a == NULL) + return b; + else if (b == NULL) + return a; + std::vector v; + ConvertToVector(b, &v); + const Entry *ans = a; + for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]); + return ans; + } + const Entry *CommonPrefix(const Entry *a, const Entry *b) { + std::vector a_vec, b_vec; + ConvertToVector(a, &a_vec); + ConvertToVector(b, &b_vec); + const Entry *ans = NULL; + for (size_t i = 0; + i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++) + ans = Successor(ans, a_vec[i]); + return ans; + } + + // removes any elements from b that are not part of + // a common prefix with a. + void ReduceToCommonPrefix(const Entry *a, std::vector *b) { + size_t a_size = Size(a), b_size = b->size(); + while (a_size > b_size) { + a = a->parent; + a_size--; + } + if (b_size > a_size) b_size = a_size; + typename std::vector::iterator b_begin = b->begin(); + while (a_size != 0) { + if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1; + a = a->parent; + a_size--; + } + if (b_size != b->size()) b->resize(b_size); + } + + // removes the first n elements of a. + const Entry *RemovePrefix(const Entry *a, size_t n) { + if (n == 0) return a; + std::vector a_vec; + ConvertToVector(a, &a_vec); + assert(a_vec.size() >= n); + const Entry *ans = NULL; + for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]); + return ans; + } + + // Returns true if a is a prefix of b. If a is prefix of b, + // time taken is |b| - |a|. Else, time taken is |b|. + bool IsPrefixOf(const Entry *a, const Entry *b) const { + if (a == NULL) return true; // empty string prefix of all. 
+ if (a == b) return true; + if (b == NULL) return false; + return IsPrefixOf(a, b->parent); + } + + inline size_t Size(const Entry *entry) const { + size_t ans = 0; + while (entry != NULL) { + ans++; + entry = entry->parent; + } + return ans; + } + + void ConvertToVector(const Entry *entry, std::vector *out) const { + size_t length = Size(entry); + out->resize(length); + if (entry != NULL) { + typename std::vector::reverse_iterator iter = out->rbegin(); + while (entry != NULL) { + *iter = entry->i; + entry = entry->parent; + ++iter; + } + } + } + + const Entry *ConvertFromVector(const std::vector &vec) { + const Entry *e = NULL; + for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]); + return e; + } + + LatticeStringRepository() { new_entry_ = new Entry; } + + void Destroy() { + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) + delete *iter; + SetType tmp; + tmp.swap(set_); + if (new_entry_) { + delete new_entry_; + new_entry_ = NULL; + } + } + + // Rebuild will rebuild this object, guaranteeing only + // to preserve the Entry values that are in the vector pointed + // to (this list does not have to be unique). The point of + // this is to save memory. + void Rebuild(const std::vector &to_keep) { + SetType tmp_set; + for (typename std::vector::const_iterator iter = + to_keep.begin(); + iter != to_keep.end(); ++iter) + RebuildHelper(*iter, &tmp_set); + // Now delete all elems not in tmp_set. + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) { + if (tmp_set.count(*iter) == 0) + delete (*iter); // delete the Entry; not needed. + } + set_.swap(tmp_set); + } + + ~LatticeStringRepository() { Destroy(); } + int32 MemSize() const { + return set_.size() * sizeof(Entry) * 2; // this is a lower bound + // on the size this structure might take. + } + + private: + class EntryKey { // Hash function object. + public: + inline size_t operator()(const Entry *entry) const { + size_t prime = 49109; + return static_cast(entry->i) + + prime * reinterpret_cast(entry->parent); + } + }; + class EntryEqual { + public: + inline bool operator()(const Entry *e1, const Entry *e2) const { + return (*e1 == *e2); + } + }; + typedef std::unordered_set SetType; + + void RebuildHelper(const Entry *to_add, SetType *tmp_set) { + while (true) { + if (to_add == NULL) return; + typename SetType::iterator iter = tmp_set->find(to_add); + if (iter == tmp_set->end()) { // not in tmp_set. + tmp_set->insert(to_add); + to_add = to_add->parent; // and loop. + } else { + return; + } + } + } + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository); + Entry *new_entry_; // We always have a pre-allocated Entry ready to use, + // to avoid unnecessary news and deletes. + SetType set_; +}; + +// class LatticeDeterminizer is templated on the same types that +// CompactLatticeWeight is templated on: the base weight (Weight), typically +// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the +// IntType, typically int32, used for the output symbols in the compact +// representation of strings [note: the output symbols would usually be +// p.d.f. id's in the anticipated use of this code] It has a special requirement +// on the Weight type: that there should be a Compare function on the weights +// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 +// > w2. This requires that there be a total order on the weights. 
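
The block comment above states the one special requirement on the weight type: a Compare(w1, w2) function returning -1, 0, or +1 that induces a total order. Below is a tiny Python illustration of one Compare obeying that contract for (cost, label sequence) pairs; it illustrates the contract only and is not the specific order Kaldi defines for compact lattice weights.

    def compare(w1, w2):
        """One possible total order on (cost, label_sequence) weights:
        lower cost first, ties broken on the label sequence."""
        (c1, s1), (c2, s2) = w1, w2
        if c1 != c2:
            return -1 if c1 < c2 else 1
        if s1 != s2:
            return -1 if list(s1) < list(s2) else 1
        return 0

    # compare((4.0, (3, 5)), (4.0, (3, 7))) -> -1
    # compare((2.0, ()),     (4.0, (1,)))   -> -1
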
+ +template +class LatticeDeterminizer { + public: + // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 + // correspondence between our states and the states in ofst. If destroy == + // true, release memory as we go (but we cannot output again). + + typedef CompactLatticeWeightTpl CompactWeight; + typedef ArcTpl + CompactArc; // arc in compact, acceptor form of lattice + typedef ArcTpl Arc; // arc in non-compact version of lattice + + // Output to standard FST with CompactWeightTpl as its weight type + // (the weight stores the original output-symbol strings). If destroy == + // true, release memory as we go (but we cannot output again). + void Output(MutableFst *ofst, bool destroy = true) { + assert(determinized_); + typedef typename Arc::StateId StateId; + StateId nStates = static_cast(output_arcs_.size()); + if (destroy) FreeMostMemory(); + ofst->DeleteStates(); + ofst->SetStart(kNoStateId); + if (nStates == 0) { + return; + } + for (StateId s = 0; s < nStates; s++) { + OutputStateId news = ofst->AddState(); + assert(news == s); + } + ofst->SetStart(0); + // now process transitions. + for (StateId this_state = 0; this_state < nStates; this_state++) { + std::vector &this_vec(output_arcs_[this_state]); + typename std::vector::const_iterator iter = this_vec.begin(), + end = this_vec.end(); + + for (; iter != end; ++iter) { + const TempArc &temp_arc(*iter); + CompactArc new_arc; + std::vector