commit
c2ee6bc67d
@ -0,0 +1,164 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Prepare VoxCeleb2 dataset
|
||||||
|
|
||||||
|
Download and unpack the voxceleb2 data files.
|
||||||
|
Voxceleb2 data is stored as the m4a format,
|
||||||
|
so we need convert the m4a to wav with the convert.sh scripts
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import codecs
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from utils.utility import download
|
||||||
|
from utils.utility import unzip
|
||||||
|
|
||||||
|
# all the data will be download in the current data/voxceleb directory default
|
||||||
|
DATA_HOME = os.path.expanduser('.')
|
||||||
|
|
||||||
|
BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/"
|
||||||
|
|
||||||
|
# dev data
|
||||||
|
DEV_DATA_URL = BASE_URL + '/vox2_aac.zip'
|
||||||
|
DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402"
|
||||||
|
|
||||||
|
# test data
|
||||||
|
TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip'
|
||||||
|
TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312"
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_dir",
|
||||||
|
default=DATA_HOME + "/voxceleb2/",
|
||||||
|
type=str,
|
||||||
|
help="Directory to save the voxceleb1 dataset. (default: %(default)s)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--manifest_prefix",
|
||||||
|
default="manifest",
|
||||||
|
type=str,
|
||||||
|
help="Filepath prefix for output manifests. (default: %(default)s)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--download",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Download the voxceleb2 dataset. (default: %(default)s)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--generate",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Generate the manifest files. (default: %(default)s)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def create_manifest(data_dir, manifest_path_prefix):
|
||||||
|
print("Creating manifest %s ..." % manifest_path_prefix)
|
||||||
|
json_lines = []
|
||||||
|
data_path = os.path.join(data_dir, "**", "*.wav")
|
||||||
|
total_sec = 0.0
|
||||||
|
total_text = 0.0
|
||||||
|
total_num = 0
|
||||||
|
speakers = set()
|
||||||
|
for audio_path in glob.glob(data_path, recursive=True):
|
||||||
|
audio_id = "-".join(audio_path.split("/")[-3:])
|
||||||
|
utt2spk = audio_path.split("/")[-3]
|
||||||
|
duration = soundfile.info(audio_path).duration
|
||||||
|
text = ""
|
||||||
|
json_lines.append(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"utt": audio_id,
|
||||||
|
"utt2spk": str(utt2spk),
|
||||||
|
"feat": audio_path,
|
||||||
|
"feat_shape": (duration, ),
|
||||||
|
"text": text # compatible with asr data format
|
||||||
|
},
|
||||||
|
ensure_ascii=False))
|
||||||
|
|
||||||
|
total_sec += duration
|
||||||
|
total_text += len(text)
|
||||||
|
total_num += 1
|
||||||
|
speakers.add(utt2spk)
|
||||||
|
|
||||||
|
# data_dir_name refer to dev or test
|
||||||
|
# voxceleb2 is given explicit in the path
|
||||||
|
data_dir_name = Path(data_dir).name
|
||||||
|
manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.dirname(manifest_path_prefix)):
|
||||||
|
os.makedirs(os.path.dirname(manifest_path_prefix))
|
||||||
|
with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
|
||||||
|
for line in json_lines:
|
||||||
|
f.write(line + "\n")
|
||||||
|
|
||||||
|
manifest_dir = os.path.dirname(manifest_path_prefix)
|
||||||
|
meta_path = os.path.join(manifest_dir, "voxceleb2." +
|
||||||
|
data_dir_name) + ".meta"
|
||||||
|
with codecs.open(meta_path, 'w', encoding='utf-8') as f:
|
||||||
|
print(f"{total_num} utts", file=f)
|
||||||
|
print(f"{len(speakers)} speakers", file=f)
|
||||||
|
print(f"{total_sec / (60 * 60)} h", file=f)
|
||||||
|
print(f"{total_text} text", file=f)
|
||||||
|
print(f"{total_text / total_sec} text/sec", file=f)
|
||||||
|
print(f"{total_sec / total_num} sec/utt", file=f)
|
||||||
|
|
||||||
|
|
||||||
|
def download_dataset(url, md5sum, target_dir, dataset):
|
||||||
|
if not os.path.exists(target_dir):
|
||||||
|
os.makedirs(target_dir)
|
||||||
|
|
||||||
|
# wav directory already exists, it need do nothing
|
||||||
|
print("target dir {}".format(os.path.join(target_dir, dataset)))
|
||||||
|
# unzip the dev dataset will create the dev and unzip the m4a to dev dir
|
||||||
|
# but the test dataset will unzip to aac
|
||||||
|
# so, wo create the ${target_dir}/test and unzip the m4a to test dir
|
||||||
|
if not os.path.exists(os.path.join(target_dir, dataset)):
|
||||||
|
filepath = download(url, md5sum, target_dir)
|
||||||
|
if dataset == "test":
|
||||||
|
unzip(filepath, os.path.join(target_dir, "test"))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if args.target_dir.startswith('~'):
|
||||||
|
args.target_dir = os.path.expanduser(args.target_dir)
|
||||||
|
|
||||||
|
# download and unpack the vox2-dev data
|
||||||
|
print("download: {}".format(args.download))
|
||||||
|
if args.download:
|
||||||
|
download_dataset(
|
||||||
|
url=DEV_DATA_URL,
|
||||||
|
md5sum=DEV_MD5SUM,
|
||||||
|
target_dir=args.target_dir,
|
||||||
|
dataset="dev")
|
||||||
|
|
||||||
|
download_dataset(
|
||||||
|
url=TEST_DATA_URL,
|
||||||
|
md5sum=TEST_MD5SUM,
|
||||||
|
target_dir=args.target_dir,
|
||||||
|
dataset="test")
|
||||||
|
|
||||||
|
print("VoxCeleb2 download is done!")
|
||||||
|
|
||||||
|
if args.generate:
|
||||||
|
create_manifest(
|
||||||
|
args.target_dir, manifest_path_prefix=args.manifest_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Before Width: | Height: | Size: 80 KiB After Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 81 KiB |
@ -1,12 +1,13 @@
|
|||||||
soundfile==0.10.3.post1
|
|
||||||
librosa==0.8.0
|
|
||||||
numpy
|
|
||||||
pymysql
|
|
||||||
fastapi
|
|
||||||
uvicorn
|
|
||||||
diskcache==5.2.1
|
diskcache==5.2.1
|
||||||
|
dtaidistance==2.3.1
|
||||||
|
fastapi
|
||||||
|
librosa==0.8.0
|
||||||
|
numpy==1.21.0
|
||||||
|
pydantic
|
||||||
pymilvus==2.0.1
|
pymilvus==2.0.1
|
||||||
|
pymysql
|
||||||
python-multipart
|
python-multipart
|
||||||
typing
|
soundfile==0.10.3.post1
|
||||||
starlette
|
starlette
|
||||||
pydantic
|
typing
|
||||||
|
uvicorn
|
@ -0,0 +1,158 @@
|
|||||||
|
([简体中文](./README_cn.md)|English)
|
||||||
|
# Speech Verification)
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Speaker Verification, refers to the problem of getting a speaker embedding from an audio.
|
||||||
|
|
||||||
|
This demo is an implementation to extract speaker embedding from a specific audio file. It can be done by a single command or a few lines in python using `PaddleSpeech`.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
### 1. Installation
|
||||||
|
see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
|
||||||
|
|
||||||
|
You can choose one way from easy, meduim and hard to install paddlespeech.
|
||||||
|
|
||||||
|
### 2. Prepare Input File
|
||||||
|
The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
|
||||||
|
|
||||||
|
Here are sample files for this demo that can be downloaded:
|
||||||
|
```bash
|
||||||
|
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Usage
|
||||||
|
- Command Line(Recommended)
|
||||||
|
```bash
|
||||||
|
paddlespeech vector --task spk --input 85236145389.wav
|
||||||
|
|
||||||
|
echo -e "demo1 85236145389.wav" > vec.job
|
||||||
|
paddlespeech vector --task spk --input vec.job
|
||||||
|
|
||||||
|
echo -e "demo2 85236145389.wav \n demo3 85236145389.wav" | paddlespeech vector --task spk
|
||||||
|
```
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```bash
|
||||||
|
paddlespeech vector --help
|
||||||
|
```
|
||||||
|
Arguments:
|
||||||
|
- `input`(required): Audio file to recognize.
|
||||||
|
- `model`: Model type of vector task. Default: `ecapatdnn_voxceleb12`.
|
||||||
|
- `sample_rate`: Sample rate of the model. Default: `16000`.
|
||||||
|
- `config`: Config of vector task. Use pretrained model when it is None. Default: `None`.
|
||||||
|
- `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
|
||||||
|
- `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment.
|
||||||
|
|
||||||
|
Output:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
demo [ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
|
||||||
|
-3.04878 1.611095 10.127234 -10.534177 -15.821609
|
||||||
|
1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
|
||||||
|
-11.343508 2.3385992 -8.719341 14.213509 15.404744
|
||||||
|
-0.39327756 6.338786 2.688887 8.7104025 17.469526
|
||||||
|
-8.77959 7.0576906 4.648855 -1.3089896 -23.294737
|
||||||
|
8.013747 13.891729 -9.926753 5.655307 -5.9422326
|
||||||
|
-22.842539 0.6293588 -18.46266 -10.811862 9.8192625
|
||||||
|
3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
|
||||||
|
1.7594414 -0.6485091 4.485623 2.0207152 7.264915
|
||||||
|
-6.40137 23.63524 2.9711294 -22.708025 9.93719
|
||||||
|
20.354511 -10.324688 -0.700492 -8.783211 -5.27593
|
||||||
|
15.999649 3.3004563 12.747926 15.429879 4.7849145
|
||||||
|
5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
|
||||||
|
15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
|
||||||
|
-9.224193 14.568347 -10.568833 4.982321 -4.342062
|
||||||
|
0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
|
||||||
|
-6.680575 0.4757669 -5.035051 -6.7964664 16.865469
|
||||||
|
-11.54324 7.681869 0.44475392 9.708182 -8.932846
|
||||||
|
0.4123232 -4.361452 1.3948607 9.511665 0.11667654
|
||||||
|
2.9079323 6.049952 9.275183 -18.078873 6.2983274
|
||||||
|
-0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
|
||||||
|
4.010979 11.000591 -2.8873312 7.1352735 -16.79663
|
||||||
|
18.495346 -14.293832 7.89578 2.2714825 22.976387
|
||||||
|
-4.875734 -3.0836344 -2.9999814 13.751918 6.448228
|
||||||
|
-11.924197 2.171869 2.0423572 -6.173772 10.778437
|
||||||
|
25.77281 -4.9495463 14.57806 0.3044315 2.6132357
|
||||||
|
-7.591999 -2.076944 9.025118 1.7834753 -3.1799617
|
||||||
|
-4.9401326 23.465864 5.1685796 -9.018578 9.037825
|
||||||
|
-4.4150195 6.859591 -12.274467 -0.88911164 5.186309
|
||||||
|
-3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
|
||||||
|
-12.397416 -12.719869 -1.395601 2.1150916 5.7381287
|
||||||
|
-4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
|
||||||
|
8.731719 -20.778936 -11.495662 5.8033476 -4.752041
|
||||||
|
10.833007 -6.717991 4.504732 13.4244375 1.1306485
|
||||||
|
7.3435574 1.400918 14.704036 -9.501399 7.2315617
|
||||||
|
-6.417456 1.3333273 11.872697 -0.30664724 8.8845
|
||||||
|
6.5569253 4.7948146 0.03662816 -8.704245 6.224871
|
||||||
|
-3.2701402 -11.508579 ]
|
||||||
|
```
|
||||||
|
|
||||||
|
- Python API
|
||||||
|
```python
|
||||||
|
import paddle
|
||||||
|
from paddlespeech.cli import VectorExecutor
|
||||||
|
|
||||||
|
vector_executor = VectorExecutor()
|
||||||
|
audio_emb = vector_executor(
|
||||||
|
model='ecapatdnn_voxceleb12',
|
||||||
|
sample_rate=16000,
|
||||||
|
config=None,
|
||||||
|
ckpt_path=None,
|
||||||
|
audio_file='./85236145389.wav',
|
||||||
|
force_yes=False,
|
||||||
|
device=paddle.get_device())
|
||||||
|
print('Audio embedding Result: \n{}'.format(audio_emb))
|
||||||
|
```
|
||||||
|
|
||||||
|
Output:
|
||||||
|
```bash
|
||||||
|
# Vector Result:
|
||||||
|
[ -5.749211 9.505463 -8.200284 -5.2075014 5.3940268
|
||||||
|
-3.04878 1.611095 10.127234 -10.534177 -15.821609
|
||||||
|
1.2032688 -0.35080156 1.2629458 -12.643498 -2.5758228
|
||||||
|
-11.343508 2.3385992 -8.719341 14.213509 15.404744
|
||||||
|
-0.39327756 6.338786 2.688887 8.7104025 17.469526
|
||||||
|
-8.77959 7.0576906 4.648855 -1.3089896 -23.294737
|
||||||
|
8.013747 13.891729 -9.926753 5.655307 -5.9422326
|
||||||
|
-22.842539 0.6293588 -18.46266 -10.811862 9.8192625
|
||||||
|
3.0070958 3.8072643 -2.3861165 3.0821571 -14.739942
|
||||||
|
1.7594414 -0.6485091 4.485623 2.0207152 7.264915
|
||||||
|
-6.40137 23.63524 2.9711294 -22.708025 9.93719
|
||||||
|
20.354511 -10.324688 -0.700492 -8.783211 -5.27593
|
||||||
|
15.999649 3.3004563 12.747926 15.429879 4.7849145
|
||||||
|
5.6699696 -2.3826702 10.605882 3.9112158 3.1500628
|
||||||
|
15.859915 -2.1832209 -23.908653 -6.4799504 -4.5365124
|
||||||
|
-9.224193 14.568347 -10.568833 4.982321 -4.342062
|
||||||
|
0.0914714 12.645902 -5.74285 -3.2141201 -2.7173362
|
||||||
|
-6.680575 0.4757669 -5.035051 -6.7964664 16.865469
|
||||||
|
-11.54324 7.681869 0.44475392 9.708182 -8.932846
|
||||||
|
0.4123232 -4.361452 1.3948607 9.511665 0.11667654
|
||||||
|
2.9079323 6.049952 9.275183 -18.078873 6.2983274
|
||||||
|
-0.7500531 -2.725033 -7.6027865 3.3404543 2.990815
|
||||||
|
4.010979 11.000591 -2.8873312 7.1352735 -16.79663
|
||||||
|
18.495346 -14.293832 7.89578 2.2714825 22.976387
|
||||||
|
-4.875734 -3.0836344 -2.9999814 13.751918 6.448228
|
||||||
|
-11.924197 2.171869 2.0423572 -6.173772 10.778437
|
||||||
|
25.77281 -4.9495463 14.57806 0.3044315 2.6132357
|
||||||
|
-7.591999 -2.076944 9.025118 1.7834753 -3.1799617
|
||||||
|
-4.9401326 23.465864 5.1685796 -9.018578 9.037825
|
||||||
|
-4.4150195 6.859591 -12.274467 -0.88911164 5.186309
|
||||||
|
-3.9988663 -13.638606 -9.925445 -0.06329413 -3.6709652
|
||||||
|
-12.397416 -12.719869 -1.395601 2.1150916 5.7381287
|
||||||
|
-4.4691963 -3.82819 -0.84233856 -1.1604277 -13.490127
|
||||||
|
8.731719 -20.778936 -11.495662 5.8033476 -4.752041
|
||||||
|
10.833007 -6.717991 4.504732 13.4244375 1.1306485
|
||||||
|
7.3435574 1.400918 14.704036 -9.501399 7.2315617
|
||||||
|
-6.417456 1.3333273 11.872697 -0.30664724 8.8845
|
||||||
|
6.5569253 4.7948146 0.03662816 -8.704245 6.224871
|
||||||
|
-3.2701402 -11.508579 ]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.Pretrained Models
|
||||||
|
|
||||||
|
Here is a list of pretrained models released by PaddleSpeech that can be used by command and python API:
|
||||||
|
|
||||||
|
| Model | Sample Rate
|
||||||
|
| :--- | :---: |
|
||||||
|
| ecapatdnn_voxceleb12 | 16k
|
@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
|
||||||
|
|
||||||
|
# asr
|
||||||
|
paddlespeech vector --task spk --input ./85236145389.wav
|
@ -0,0 +1,107 @@
|
|||||||
|
# use CNND
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
|
||||||
|
fs: 24000 # sr
|
||||||
|
n_fft: 2048 # FFT size (samples).
|
||||||
|
n_shift: 300 # Hop size (samples). 12.5ms
|
||||||
|
win_length: 1200 # Window length (samples). 50ms
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
|
||||||
|
# Only used for feats_type != raw
|
||||||
|
|
||||||
|
fmin: 80 # Minimum frequency of Mel basis.
|
||||||
|
fmax: 7600 # Maximum frequency of Mel basis.
|
||||||
|
n_mels: 80 # The number of mel basis.
|
||||||
|
|
||||||
|
# Only used for the model using pitch features (e.g. FastSpeech2)
|
||||||
|
f0min: 80 # Minimum f0 for pitch extraction.
|
||||||
|
f0max: 400 # Maximum f0 for pitch extraction.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# MODEL SETTING #
|
||||||
|
###########################################################
|
||||||
|
model:
|
||||||
|
adim: 384 # attention dimension
|
||||||
|
aheads: 2 # number of attention heads
|
||||||
|
elayers: 4 # number of encoder layers
|
||||||
|
eunits: 1536 # number of encoder ff units
|
||||||
|
dlayers: 4 # number of decoder layers
|
||||||
|
dunits: 1536 # number of decoder ff units
|
||||||
|
positionwise_layer_type: conv1d # type of position-wise layer
|
||||||
|
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
|
||||||
|
duration_predictor_layers: 2 # number of layers of duration predictor
|
||||||
|
duration_predictor_chans: 256 # number of channels of duration predictor
|
||||||
|
duration_predictor_kernel_size: 3 # filter size of duration predictor
|
||||||
|
postnet_layers: 5 # number of layers of postnset
|
||||||
|
postnet_filts: 5 # filter size of conv layers in postnet
|
||||||
|
postnet_chans: 256 # number of channels of conv layers in postnet
|
||||||
|
use_scaled_pos_enc: True # whether to use scaled positional encoding
|
||||||
|
encoder_normalize_before: True # whether to perform layer normalization before the input
|
||||||
|
decoder_normalize_before: True # whether to perform layer normalization before the input
|
||||||
|
reduction_factor: 1 # reduction factor
|
||||||
|
encoder_type: transformer # encoder type
|
||||||
|
decoder_type: cnndecoder # decoder type
|
||||||
|
init_type: xavier_uniform # initialization type
|
||||||
|
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
|
||||||
|
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
|
||||||
|
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
|
||||||
|
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
|
||||||
|
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
|
||||||
|
cnn_dec_dropout_rate: 0.2 # dropout rate for cnn decoder layer
|
||||||
|
cnn_postnet_dropout_rate: 0.2
|
||||||
|
cnn_postnet_resblock_kernel_sizes: [256, 256] # kernel sizes for residual block of cnn_postnet
|
||||||
|
cnn_postnet_kernel_size: 5 # kernel size of cnn_postnet
|
||||||
|
cnn_decoder_embedding_dim: 256
|
||||||
|
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
|
||||||
|
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
|
||||||
|
pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor
|
||||||
|
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
|
||||||
|
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
|
||||||
|
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
|
||||||
|
stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
|
||||||
|
energy_predictor_layers: 2 # number of conv layers in energy predictor
|
||||||
|
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
|
||||||
|
energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor
|
||||||
|
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
|
||||||
|
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
|
||||||
|
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
|
||||||
|
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# UPDATER SETTING #
|
||||||
|
###########################################################
|
||||||
|
updater:
|
||||||
|
use_masking: True # whether to apply masking for padded part in loss calculation
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER SETTING #
|
||||||
|
###########################################################
|
||||||
|
optimizer:
|
||||||
|
optim: adam # optimizer type
|
||||||
|
learning_rate: 0.001 # learning rate
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# TRAINING SETTING #
|
||||||
|
###########################################################
|
||||||
|
max_epoch: 1000
|
||||||
|
num_snapshots: 5
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
seed: 10086
|
@ -0,0 +1,92 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=0
|
||||||
|
|
||||||
|
# pwgan
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize_streaming.py \
|
||||||
|
--am=fastspeech2_csmsc \
|
||||||
|
--am_config=${config_path} \
|
||||||
|
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--am_stat=dump/train/speech_stats.npy \
|
||||||
|
--voc=pwgan_csmsc \
|
||||||
|
--voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
|
||||||
|
--voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
|
||||||
|
--voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
|
||||||
|
--lang=zh \
|
||||||
|
--text=${BIN_DIR}/../sentences.txt \
|
||||||
|
--output_dir=${train_output_path}/test_e2e_streaming \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--am_streaming=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
# for more GAN Vocoders
|
||||||
|
# multi band melgan
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize_streaming.py \
|
||||||
|
--am=fastspeech2_csmsc \
|
||||||
|
--am_config=${config_path} \
|
||||||
|
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--am_stat=dump/train/speech_stats.npy \
|
||||||
|
--voc=mb_melgan_csmsc \
|
||||||
|
--voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
|
||||||
|
--voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
|
||||||
|
--voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
|
||||||
|
--lang=zh \
|
||||||
|
--text=${BIN_DIR}/../sentences.txt \
|
||||||
|
--output_dir=${train_output_path}/test_e2e_streaming \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--am_streaming=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
# the pretrained models haven't release now
|
||||||
|
# style melgan
|
||||||
|
# style melgan's Dygraph to Static Graph is not ready now
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize_streaming.py \
|
||||||
|
--am=fastspeech2_csmsc \
|
||||||
|
--am_config=${config_path} \
|
||||||
|
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--am_stat=dump/train/speech_stats.npy \
|
||||||
|
--voc=style_melgan_csmsc \
|
||||||
|
--voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
|
||||||
|
--voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
|
||||||
|
--voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
|
||||||
|
--lang=zh \
|
||||||
|
--text=${BIN_DIR}/../sentences.txt \
|
||||||
|
--output_dir=${train_output_path}/test_e2e_streaming \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--am_streaming=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
# hifigan
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
echo "in hifigan syn_e2e"
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize_streaming.py \
|
||||||
|
--am=fastspeech2_csmsc \
|
||||||
|
--am_config=${config_path} \
|
||||||
|
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--am_stat=dump/train/speech_stats.npy \
|
||||||
|
--voc=hifigan_csmsc \
|
||||||
|
--voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
|
||||||
|
--voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
|
||||||
|
--voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
|
||||||
|
--lang=zh \
|
||||||
|
--text=${BIN_DIR}/../sentences.txt \
|
||||||
|
--output_dir=${train_output_path}/test_e2e_streaming \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--am_streaming=True
|
||||||
|
fi
|
@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/cnndecoder.yaml
|
||||||
|
train_output_path=exp/cnndecoder
|
||||||
|
ckpt_name=snapshot_iter_153.pdz
|
||||||
|
|
||||||
|
# with the following command, you can choose the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# synthesize, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# synthesize_e2e, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
|
# inference with static model
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
|
# synthesize_e2e, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
||||||
|
|
@ -0,0 +1,7 @@
|
|||||||
|
# VoxCeleb
|
||||||
|
|
||||||
|
## ECAPA-TDNN
|
||||||
|
|
||||||
|
| Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm |
|
||||||
|
| --- | --- | --- | --- | --- | --- | --- | ---- |
|
||||||
|
| ECAPA-TDNN | 85M | 0.1.1 | conf/ecapa_tdnn.yaml |192 | test | 1.15 | 1.06 |
|
@ -0,0 +1,52 @@
|
|||||||
|
###########################################
|
||||||
|
# Data #
|
||||||
|
###########################################
|
||||||
|
# we should explicitly specify the wav path of vox2 audio data converted from m4a
|
||||||
|
vox2_base_path:
|
||||||
|
augment: True
|
||||||
|
batch_size: 16
|
||||||
|
num_workers: 2
|
||||||
|
num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
|
||||||
|
shuffle: True
|
||||||
|
random_chunk: True
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
# currently, we only support fbank
|
||||||
|
sr: 16000 # sample rate
|
||||||
|
n_mels: 80
|
||||||
|
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
|
||||||
|
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# MODEL SETTING #
|
||||||
|
###########################################################
|
||||||
|
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
|
||||||
|
# if we want use another model, please choose another configuration yaml file
|
||||||
|
model:
|
||||||
|
input_size: 80
|
||||||
|
# "channels": [512, 512, 512, 512, 1536],
|
||||||
|
channels: [1024, 1024, 1024, 1024, 3072]
|
||||||
|
kernel_sizes: [5, 3, 3, 3, 1]
|
||||||
|
dilations: [1, 2, 3, 4, 1]
|
||||||
|
attention_channels: 128
|
||||||
|
lin_neurons: 192
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Training #
|
||||||
|
###########################################
|
||||||
|
seed: 1986 # according from speechbrain configuration
|
||||||
|
epochs: 10
|
||||||
|
save_interval: 1
|
||||||
|
log_interval: 1
|
||||||
|
learning_rate: 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Testing #
|
||||||
|
###########################################
|
||||||
|
global_embedding_norm: True
|
||||||
|
embedding_mean_norm: True
|
||||||
|
embedding_std_norm: False
|
||||||
|
|
@ -0,0 +1,58 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
stage=1
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
|
||||||
|
|
||||||
|
if [ $# -ne 2 ] ; then
|
||||||
|
echo "Usage: $0 [options] <data-dir> <conf-path>";
|
||||||
|
echo "e.g.: $0 ./data/ conf/ecapa_tdnn.yaml"
|
||||||
|
echo "Options: "
|
||||||
|
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
|
||||||
|
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
dir=$1
|
||||||
|
conf_path=$2
|
||||||
|
mkdir -p ${dir}
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
|
||||||
|
# we should use the local/convert.sh convert m4a to wav
|
||||||
|
python3 local/data_prepare.py \
|
||||||
|
--data-dir ${dir} \
|
||||||
|
--config ${conf_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
TARGET_DIR=${MAIN_ROOT}/dataset
|
||||||
|
mkdir -p ${TARGET_DIR}
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# download data, generate manifests
|
||||||
|
python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
|
||||||
|
--manifest_prefix="data/vox1/manifest" \
|
||||||
|
--target_dir="${TARGET_DIR}/voxceleb/vox1/"
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Prepare voxceleb failed. Terminated."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# for dataset in train dev test; do
|
||||||
|
# mv data/manifest.${dataset} data/manifest.${dataset}.raw
|
||||||
|
# done
|
||||||
|
fi
|
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddleaudio.datasets.voxceleb import VoxCeleb
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.vector.io.augment import build_augment_pipeline
|
||||||
|
from paddlespeech.vector.training.seeding import seed_everything
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args, config):
|
||||||
|
|
||||||
|
# stage0: set the cpu device, all data prepare process will be done in cpu mode
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
# set the random seed, it is a must for multiprocess training
|
||||||
|
seed_everything(config.seed)
|
||||||
|
|
||||||
|
# stage 1: generate the voxceleb csv file
|
||||||
|
# Note: this may occurs c++ execption, but the program will execute fine
|
||||||
|
# so we ignore the execption
|
||||||
|
# we explicitly pass the vox2 base path to data prepare and generate the audio info
|
||||||
|
logger.info("start to generate the voxceleb dataset info")
|
||||||
|
train_dataset = VoxCeleb(
|
||||||
|
'train', target_dir=args.data_dir, vox2_base_path=config.vox2_base_path)
|
||||||
|
|
||||||
|
# stage 2: generate the augment noise csv file
|
||||||
|
if config.augment:
|
||||||
|
logger.info("start to generate the augment dataset info")
|
||||||
|
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# yapf: disable
|
||||||
|
parser = argparse.ArgumentParser(__doc__)
|
||||||
|
parser.add_argument("--data-dir",
|
||||||
|
default="./data/",
|
||||||
|
type=str,
|
||||||
|
help="data directory")
|
||||||
|
parser.add_argument("--config",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="configuration file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
# https://yaml.org/type/float.html
|
||||||
|
config = CfgNode(new_allowed=True)
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
|
||||||
|
config.freeze()
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
main(args, config)
|
@ -0,0 +1,51 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
. ./path.sh
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
|
||||||
|
conf_path=conf/ecapa_tdnn.yaml
|
||||||
|
audio_path="demo/voxceleb/00001.wav"
|
||||||
|
use_gpu=true
|
||||||
|
|
||||||
|
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
|
||||||
|
|
||||||
|
if [ $# -ne 0 ] ; then
|
||||||
|
echo "Usage: $0 [options]";
|
||||||
|
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
|
||||||
|
echo "Options: "
|
||||||
|
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
|
||||||
|
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
|
||||||
|
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
|
||||||
|
echo " --exp-dir # experiment directorh, where is has the model.pdparams"
|
||||||
|
echo " --conf-path # configuration file for extracting the embedding"
|
||||||
|
echo " --audio-path # audio-path, which will be processed to extract the embedding"
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
# set the test device
|
||||||
|
device="cpu"
|
||||||
|
if ${use_gpu}; then
|
||||||
|
device="gpu"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# extract the audio embedding
|
||||||
|
python3 ${BIN_DIR}/extract_emb.py --device ${device} \
|
||||||
|
--config ${conf_path} \
|
||||||
|
--audio-path ${audio_path} --load-checkpoint ${exp_dir}
|
||||||
|
fi
|
@ -0,0 +1,42 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
stage=1
|
||||||
|
stop_stage=100
|
||||||
|
use_gpu=true # if true, we run on GPU.
|
||||||
|
|
||||||
|
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
|
||||||
|
|
||||||
|
if [ $# -ne 3 ] ; then
|
||||||
|
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
|
||||||
|
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
|
||||||
|
echo "Options: "
|
||||||
|
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
|
||||||
|
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
|
||||||
|
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
dir=$1
|
||||||
|
exp_dir=$2
|
||||||
|
conf_path=$3
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# test the model and compute the eer metrics
|
||||||
|
python3 ${BIN_DIR}/test.py \
|
||||||
|
--data-dir ${dir} \
|
||||||
|
--load-checkpoint ${exp_dir} \
|
||||||
|
--config ${conf_path}
|
||||||
|
fi
|
@ -0,0 +1,61 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
use_gpu=true # if true, we run on GPU.
|
||||||
|
|
||||||
|
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
|
||||||
|
|
||||||
|
if [ $# -ne 3 ] ; then
|
||||||
|
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
|
||||||
|
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
|
||||||
|
echo "Options: "
|
||||||
|
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
|
||||||
|
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
|
||||||
|
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
dir=$1
|
||||||
|
exp_dir=$2
|
||||||
|
conf_path=$3
|
||||||
|
|
||||||
|
# get the gpu nums for training
|
||||||
|
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
echo "using $ngpu gpus..."
|
||||||
|
|
||||||
|
# setting training device
|
||||||
|
device="cpu"
|
||||||
|
if ${use_gpu}; then
|
||||||
|
device="gpu"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train the speaker identification task with voxceleb data
|
||||||
|
# and we will create the trained model parameters in ${exp_dir}/model.pdparams as the soft link
|
||||||
|
# Note: we will store the log file in exp/log directory
|
||||||
|
python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \
|
||||||
|
${BIN_DIR}/train.py --device ${device} --checkpoint-dir ${exp_dir} \
|
||||||
|
--data-dir ${dir} --config ${conf_path}
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in training!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1,28 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||||
|
|
||||||
|
MODEL=ecapa_tdnn
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}
|
@ -0,0 +1,69 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
. ./path.sh
|
||||||
|
set -e
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
|
||||||
|
# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md with the script local/convert.sh
|
||||||
|
# stage 1: train the speaker identification model
|
||||||
|
# stage 2: test speaker identification
|
||||||
|
# stage 3: extract the training embeding to train the LDA and PLDA
|
||||||
|
######################################################################
|
||||||
|
|
||||||
|
# we can set the variable PPAUDIO_HOME to specifiy the root directory of the downloaded vox1 and vox2 dataset
|
||||||
|
# default the dataset will be stored in the ~/.paddleaudio/
|
||||||
|
# the vox2 dataset is stored in m4a format, we need to convert the audio from m4a to wav yourself
|
||||||
|
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
|
||||||
|
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
|
||||||
|
# export PPAUDIO_HOME=
|
||||||
|
stage=0
|
||||||
|
stop_stage=50
|
||||||
|
|
||||||
|
# data directory
|
||||||
|
# if we set the variable ${dir}, we will store the wav info to this directory
|
||||||
|
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
|
||||||
|
# vox2 wav path, we must convert the m4a format to wav format
|
||||||
|
dir=data/ # data info directory
|
||||||
|
|
||||||
|
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
|
||||||
|
conf_path=conf/ecapa_tdnn.yaml
|
||||||
|
gpus=0,1,2,3
|
||||||
|
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||||
|
|
||||||
|
mkdir -p ${exp_dir}
|
||||||
|
|
||||||
|
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
|
||||||
|
bash ./local/data.sh ${dir} ${conf_path}|| exit -1;
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# stage 1: train the speaker identification model
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# stage 2: get the speaker verification scores with cosine function
|
||||||
|
# now we only support use cosine to get the scores
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
# if [ $stage -le 3 ]; then
|
||||||
|
# # stage 2: extract the training embeding to train the LDA and PLDA
|
||||||
|
# # todo: extract the training embedding
|
||||||
|
# fi
|
@ -0,0 +1 @@
|
|||||||
|
../../../utils/
|
@ -0,0 +1,2 @@
|
|||||||
|
.eggs
|
||||||
|
*.wav
|
@ -0,0 +1,7 @@
|
|||||||
|
# PaddleAudio
|
||||||
|
|
||||||
|
PaddleAudio is an audio library for PaddlePaddle.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
`pip install .`
|
@ -0,0 +1,19 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
#
|
||||||
|
|
||||||
|
# You can set these variables from the command line.
|
||||||
|
SPHINXOPTS =
|
||||||
|
SPHINXBUILD = sphinx-build
|
||||||
|
SOURCEDIR = source
|
||||||
|
BUILDDIR = build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
@ -0,0 +1,24 @@
|
|||||||
|
# Build docs for PaddleAudio
|
||||||
|
|
||||||
|
Execute the following steps in **current directory**.
|
||||||
|
|
||||||
|
## 1. Install
|
||||||
|
|
||||||
|
`pip install Sphinx sphinx_rtd_theme`
|
||||||
|
|
||||||
|
|
||||||
|
## 2. Generate API docs
|
||||||
|
|
||||||
|
Generate API docs from doc string.
|
||||||
|
|
||||||
|
`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates`
|
||||||
|
|
||||||
|
|
||||||
|
## 3. Build
|
||||||
|
|
||||||
|
`sphinx-build source _html`
|
||||||
|
|
||||||
|
|
||||||
|
## 4. Preview
|
||||||
|
|
||||||
|
Open `_html/index.html` for page preview.
|
After Width: | Height: | Size: 4.9 KiB |
@ -0,0 +1,35 @@
|
|||||||
|
@ECHO OFF
|
||||||
|
|
||||||
|
pushd %~dp0
|
||||||
|
|
||||||
|
REM Command file for Sphinx documentation
|
||||||
|
|
||||||
|
if "%SPHINXBUILD%" == "" (
|
||||||
|
set SPHINXBUILD=sphinx-build
|
||||||
|
)
|
||||||
|
set SOURCEDIR=source
|
||||||
|
set BUILDDIR=build
|
||||||
|
|
||||||
|
if "%1" == "" goto help
|
||||||
|
|
||||||
|
%SPHINXBUILD% >NUL 2>NUL
|
||||||
|
if errorlevel 9009 (
|
||||||
|
echo.
|
||||||
|
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||||
|
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||||
|
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||||
|
echo.may add the Sphinx directory to PATH.
|
||||||
|
echo.
|
||||||
|
echo.If you don't have Sphinx installed, grab it from
|
||||||
|
echo.http://sphinx-doc.org/
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
||||||
|
goto end
|
||||||
|
|
||||||
|
:help
|
||||||
|
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
||||||
|
|
||||||
|
:end
|
||||||
|
popd
|
@ -0,0 +1,5 @@
|
|||||||
|
.wy-nav-content {
|
||||||
|
max-width: 80%;
|
||||||
|
}
|
||||||
|
.table table{ background:#b9b9b9}
|
||||||
|
.table table td{ background:#FFF; }
|
@ -0,0 +1,9 @@
|
|||||||
|
{%- if show_headings %}
|
||||||
|
{{- basename | e | heading }}
|
||||||
|
|
||||||
|
{% endif -%}
|
||||||
|
.. automodule:: {{ qualname }}
|
||||||
|
{%- for option in automodule_options %}
|
||||||
|
:{{ option }}:
|
||||||
|
{%- endfor %}
|
||||||
|
|
@ -0,0 +1,57 @@
|
|||||||
|
{%- macro automodule(modname, options) -%}
|
||||||
|
.. automodule:: {{ modname }}
|
||||||
|
{%- for option in options %}
|
||||||
|
:{{ option }}:
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endmacro %}
|
||||||
|
|
||||||
|
{%- macro toctree(docnames) -%}
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: {{ maxdepth }}
|
||||||
|
{% for docname in docnames %}
|
||||||
|
{{ docname }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endmacro %}
|
||||||
|
|
||||||
|
{%- if is_namespace %}
|
||||||
|
{{- [pkgname, "namespace"] | join(" ") | e | heading }}
|
||||||
|
{% else %}
|
||||||
|
{{- pkgname | e | heading }}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- if is_namespace %}
|
||||||
|
.. py:module:: {{ pkgname }}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- if modulefirst and not is_namespace %}
|
||||||
|
{{ automodule(pkgname, automodule_options) }}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- if subpackages %}
|
||||||
|
Subpackages
|
||||||
|
-----------
|
||||||
|
|
||||||
|
{{ toctree(subpackages) }}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- if submodules %}
|
||||||
|
Submodules
|
||||||
|
----------
|
||||||
|
{% if separatemodules %}
|
||||||
|
{{ toctree(submodules) }}
|
||||||
|
{% else %}
|
||||||
|
{%- for submodule in submodules %}
|
||||||
|
{% if show_headings %}
|
||||||
|
{{- submodule | e | heading(2) }}
|
||||||
|
{% endif %}
|
||||||
|
{{ automodule(submodule, automodule_options) }}
|
||||||
|
{% endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- if not modulefirst and not is_namespace %}
|
||||||
|
Module contents
|
||||||
|
---------------
|
||||||
|
|
||||||
|
{{ automodule(pkgname, automodule_options) }}
|
||||||
|
{% endif %}
|
@ -0,0 +1,8 @@
|
|||||||
|
{{ header | heading }}
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: {{ maxdepth }}
|
||||||
|
{% for docname in docnames %}
|
||||||
|
{{ docname }}
|
||||||
|
{%- endfor %}
|
||||||
|
|
@ -0,0 +1,181 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
#
|
||||||
|
# This file does only contain a selection of the most common options. For a
|
||||||
|
# full list see the documentation:
|
||||||
|
# http://www.sphinx-doc.org/en/master/config
|
||||||
|
# -- Path setup --------------------------------------------------------------
|
||||||
|
# If extensions (or modules to document with autodoc) are in another directory,
|
||||||
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
|
|
||||||
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
project = 'PaddleAudio'
|
||||||
|
copyright = '2022, PaddlePaddle'
|
||||||
|
author = 'PaddlePaddle'
|
||||||
|
|
||||||
|
# The short X.Y version
|
||||||
|
version = ''
|
||||||
|
# The full version, including alpha/beta/rc tags
|
||||||
|
release = '0.2.0'
|
||||||
|
|
||||||
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
# If your documentation needs a minimal Sphinx version, state it here.
|
||||||
|
#
|
||||||
|
# needs_sphinx = '1.0'
|
||||||
|
|
||||||
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
|
# ones.
|
||||||
|
extensions = [
|
||||||
|
'sphinx.ext.autodoc',
|
||||||
|
'sphinx.ext.intersphinx',
|
||||||
|
'sphinx.ext.mathjax',
|
||||||
|
'sphinx.ext.viewcode',
|
||||||
|
'sphinx.ext.napoleon',
|
||||||
|
]
|
||||||
|
|
||||||
|
napoleon_google_docstring = True
|
||||||
|
|
||||||
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
|
templates_path = ['_templates']
|
||||||
|
|
||||||
|
# The suffix(es) of source filenames.
|
||||||
|
# You can specify multiple suffix as a list of string:
|
||||||
|
#
|
||||||
|
# source_suffix = ['.rst', '.md']
|
||||||
|
source_suffix = '.rst'
|
||||||
|
|
||||||
|
# The master toctree document.
|
||||||
|
master_doc = 'index'
|
||||||
|
|
||||||
|
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||||
|
# for a list of supported languages.
|
||||||
|
#
|
||||||
|
# This is also used if you do content translation via gettext catalogs.
|
||||||
|
# Usually you set "language" from the command line for these cases.
|
||||||
|
language = None
|
||||||
|
|
||||||
|
# List of patterns, relative to source directory, that match files and
|
||||||
|
# directories to ignore when looking for source files.
|
||||||
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
|
exclude_patterns = []
|
||||||
|
|
||||||
|
# The name of the Pygments (syntax highlighting) style to use.
|
||||||
|
pygments_style = None
|
||||||
|
|
||||||
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
|
# a list of builtin themes.
|
||||||
|
#
|
||||||
|
|
||||||
|
import sphinx_rtd_theme
|
||||||
|
html_theme = 'sphinx_rtd_theme'
|
||||||
|
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
||||||
|
smartquotes = False
|
||||||
|
|
||||||
|
# Theme options are theme-specific and customize the look and feel of a theme
|
||||||
|
# further. For a list of options available for each theme, see the
|
||||||
|
# documentation.
|
||||||
|
#
|
||||||
|
# html_theme_options = {}
|
||||||
|
|
||||||
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
|
html_static_path = ['_static']
|
||||||
|
html_logo = '../images/paddle.png'
|
||||||
|
html_css_files = [
|
||||||
|
'custom.css',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Custom sidebar templates, must be a dictionary that maps document names
|
||||||
|
# to template names.
|
||||||
|
#
|
||||||
|
# The default sidebars (for documents that don't match any pattern) are
|
||||||
|
# defined by theme itself. Builtin themes are using these templates by
|
||||||
|
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
|
||||||
|
# 'searchbox.html']``.
|
||||||
|
#
|
||||||
|
# html_sidebars = {}
|
||||||
|
|
||||||
|
# -- Options for HTMLHelp output ---------------------------------------------
|
||||||
|
|
||||||
|
# Output file base name for HTML help builder.
|
||||||
|
htmlhelp_basename = 'PaddleAudiodoc'
|
||||||
|
|
||||||
|
# -- Options for LaTeX output ------------------------------------------------
|
||||||
|
|
||||||
|
latex_elements = {
|
||||||
|
# The paper size ('letterpaper' or 'a4paper').
|
||||||
|
#
|
||||||
|
# 'papersize': 'letterpaper',
|
||||||
|
|
||||||
|
# The font size ('10pt', '11pt' or '12pt').
|
||||||
|
#
|
||||||
|
# 'pointsize': '10pt',
|
||||||
|
|
||||||
|
# Additional stuff for the LaTeX preamble.
|
||||||
|
#
|
||||||
|
# 'preamble': '',
|
||||||
|
|
||||||
|
# Latex figure (float) alignment
|
||||||
|
#
|
||||||
|
# 'figure_align': 'htbp',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Grouping the document tree into LaTeX files. List of tuples
|
||||||
|
# (source start file, target name, title,
|
||||||
|
# author, documentclass [howto, manual, or own class]).
|
||||||
|
latex_documents = [
|
||||||
|
(master_doc, 'PaddleAudio.tex', 'PaddleAudio Documentation', 'PaddlePaddle',
|
||||||
|
'manual'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# -- Options for manual page output ------------------------------------------
|
||||||
|
|
||||||
|
# One entry per manual page. List of tuples
|
||||||
|
# (source start file, name, description, authors, manual section).
|
||||||
|
man_pages = [(master_doc, 'paddleaudio', 'PaddleAudio Documentation', [author],
|
||||||
|
1)]
|
||||||
|
|
||||||
|
# -- Options for Texinfo output ----------------------------------------------
|
||||||
|
|
||||||
|
# Grouping the document tree into Texinfo files. List of tuples
|
||||||
|
# (source start file, target name, title, author,
|
||||||
|
# dir menu entry, description, category)
|
||||||
|
texinfo_documents = [
|
||||||
|
(master_doc, 'PaddleAudio', 'PaddleAudio Documentation', author,
|
||||||
|
'PaddleAudio', 'One line description of project.', 'Miscellaneous'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# -- Options for Epub output -------------------------------------------------
|
||||||
|
|
||||||
|
# Bibliographic Dublin Core info.
|
||||||
|
epub_title = project
|
||||||
|
|
||||||
|
# The unique identifier of the text. This can be a ISBN number
|
||||||
|
# or the project homepage.
|
||||||
|
#
|
||||||
|
# epub_identifier = ''
|
||||||
|
|
||||||
|
# A unique identification for the text.
|
||||||
|
#
|
||||||
|
# epub_uid = ''
|
||||||
|
|
||||||
|
# A list of files that should not be packed into the epub file.
|
||||||
|
epub_exclude_files = ['search.html']
|
||||||
|
|
||||||
|
# -- Extension configuration -------------------------------------------------
|
||||||
|
|
||||||
|
# -- Options for intersphinx extension ---------------------------------------
|
||||||
|
|
||||||
|
# Example configuration for intersphinx: refer to the Python standard library.
|
||||||
|
intersphinx_mapping = {'https://docs.python.org/': None}
|
@ -0,0 +1,22 @@
|
|||||||
|
.. PaddleAudio documentation master file, created by
|
||||||
|
sphinx-quickstart on Tue Mar 22 15:57:16 2022.
|
||||||
|
You can adapt this file completely to your liking, but it should at least
|
||||||
|
contain the root `toctree` directive.
|
||||||
|
|
||||||
|
Welcome to PaddleAudio's documentation!
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
Index <self>
|
||||||
|
|
||||||
|
|
||||||
|
API References
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:titlesonly:
|
||||||
|
|
||||||
|
paddleaudio
|
@ -0,0 +1,201 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from paddle.io import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..backends import load as load_audio
|
||||||
|
from ..backends import save as save_wav
|
||||||
|
from ..utils import DATA_HOME
|
||||||
|
from ..utils.download import download_and_decompress
|
||||||
|
from .dataset import feat_funcs
|
||||||
|
|
||||||
|
__all__ = ['OpenRIRNoise']
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRIRNoise(Dataset):
|
||||||
|
archieves = [
|
||||||
|
{
|
||||||
|
'url': 'http://www.openslr.org/resources/28/rirs_noises.zip',
|
||||||
|
'md5': 'e6f48e257286e05de56413b4779d8ffb',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
sample_rate = 16000
|
||||||
|
meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
|
||||||
|
base_path = os.path.join(DATA_HOME, 'open_rir_noise')
|
||||||
|
wav_path = os.path.join(base_path, 'RIRS_NOISES')
|
||||||
|
csv_path = os.path.join(base_path, 'csv')
|
||||||
|
subsets = ['rir', 'noise']
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
subset: str='rir',
|
||||||
|
feat_type: str='raw',
|
||||||
|
target_dir=None,
|
||||||
|
random_chunk: bool=True,
|
||||||
|
chunk_duration: float=3.0,
|
||||||
|
seed: int=0,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert subset in self.subsets, \
|
||||||
|
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
|
||||||
|
|
||||||
|
self.subset = subset
|
||||||
|
self.feat_type = feat_type
|
||||||
|
self.feat_config = kwargs
|
||||||
|
self.random_chunk = random_chunk
|
||||||
|
self.chunk_duration = chunk_duration
|
||||||
|
|
||||||
|
OpenRIRNoise.csv_path = os.path.join(
|
||||||
|
target_dir, "open_rir_noise",
|
||||||
|
"csv") if target_dir else self.csv_path
|
||||||
|
self._data = self._get_data()
|
||||||
|
super(OpenRIRNoise, self).__init__()
|
||||||
|
|
||||||
|
# Set up a seed to reproduce training or predicting result.
|
||||||
|
# random.seed(seed)
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
# Download audio files.
|
||||||
|
print(f"rirs noises base path: {self.base_path}")
|
||||||
|
if not os.path.isdir(self.base_path):
|
||||||
|
download_and_decompress(
|
||||||
|
self.archieves, self.base_path, decompress=True)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"{self.base_path} already exists, we will not download and decompress again"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data preparation.
|
||||||
|
print(f"prepare the csv to {self.csv_path}")
|
||||||
|
if not os.path.isdir(self.csv_path):
|
||||||
|
os.makedirs(self.csv_path)
|
||||||
|
self.prepare_data()
|
||||||
|
|
||||||
|
data = []
|
||||||
|
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
|
||||||
|
for line in rf.readlines()[1:]:
|
||||||
|
audio_id, duration, wav = line.strip().split(',')
|
||||||
|
data.append(self.meta_info(audio_id, float(duration), wav))
|
||||||
|
|
||||||
|
random.shuffle(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _convert_to_record(self, idx: int):
|
||||||
|
sample = self._data[idx]
|
||||||
|
|
||||||
|
record = {}
|
||||||
|
# To show all fields in a namedtuple: `type(sample)._fields`
|
||||||
|
for field in type(sample)._fields:
|
||||||
|
record[field] = getattr(sample, field)
|
||||||
|
|
||||||
|
waveform, sr = load_audio(record['wav'])
|
||||||
|
|
||||||
|
assert self.feat_type in feat_funcs.keys(), \
|
||||||
|
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
|
||||||
|
feat_func = feat_funcs[self.feat_type]
|
||||||
|
feat = feat_func(
|
||||||
|
waveform, sr=sr, **self.feat_config) if feat_func else waveform
|
||||||
|
|
||||||
|
record.update({'feat': feat})
|
||||||
|
return record
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_chunks(seg_dur, audio_id, audio_duration):
|
||||||
|
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
|
||||||
|
|
||||||
|
chunk_lst = [
|
||||||
|
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
|
||||||
|
for i in range(num_chunks)
|
||||||
|
]
|
||||||
|
return chunk_lst
|
||||||
|
|
||||||
|
def _get_audio_info(self, wav_file: str,
|
||||||
|
split_chunks: bool) -> List[List[str]]:
|
||||||
|
waveform, sr = load_audio(wav_file)
|
||||||
|
audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0]
|
||||||
|
audio_duration = waveform.shape[0] / sr
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
if split_chunks and audio_duration > self.chunk_duration: # Split into pieces of self.chunk_duration seconds.
|
||||||
|
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
|
||||||
|
audio_duration)
|
||||||
|
|
||||||
|
for idx, chunk in enumerate(uniq_chunks_list):
|
||||||
|
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
||||||
|
start_sample = int(float(s) * sr)
|
||||||
|
end_sample = int(float(e) * sr)
|
||||||
|
new_wav_file = os.path.join(self.base_path,
|
||||||
|
audio_id + f'_chunk_{idx+1:02}.wav')
|
||||||
|
save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
|
||||||
|
# id, duration, new_wav
|
||||||
|
ret.append([chunk, self.chunk_duration, new_wav_file])
|
||||||
|
else: # Keep whole audio.
|
||||||
|
ret.append([audio_id, audio_duration, wav_file])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def generate_csv(self,
|
||||||
|
wav_files: List[str],
|
||||||
|
output_file: str,
|
||||||
|
split_chunks: bool=True):
|
||||||
|
print(f'Generating csv: {output_file}')
|
||||||
|
header = ["id", "duration", "wav"]
|
||||||
|
|
||||||
|
infos = list(
|
||||||
|
tqdm(
|
||||||
|
map(self._get_audio_info, wav_files, [split_chunks] * len(
|
||||||
|
wav_files)),
|
||||||
|
total=len(wav_files)))
|
||||||
|
|
||||||
|
csv_lines = []
|
||||||
|
for info in infos:
|
||||||
|
csv_lines.extend(info)
|
||||||
|
|
||||||
|
with open(output_file, mode="w") as csv_f:
|
||||||
|
csv_writer = csv.writer(
|
||||||
|
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||||
|
csv_writer.writerow(header)
|
||||||
|
for line in csv_lines:
|
||||||
|
csv_writer.writerow(line)
|
||||||
|
|
||||||
|
def prepare_data(self):
|
||||||
|
rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises",
|
||||||
|
"rir_list")
|
||||||
|
rir_files = []
|
||||||
|
with open(rir_list, 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
rir_file = line.strip().split(' ')[-1]
|
||||||
|
rir_files.append(os.path.join(self.base_path, rir_file))
|
||||||
|
|
||||||
|
noise_list = os.path.join(self.wav_path, "pointsource_noises",
|
||||||
|
"noise_list")
|
||||||
|
noise_files = []
|
||||||
|
with open(noise_list, 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
noise_file = line.strip().split(' ')[-1]
|
||||||
|
noise_files.append(os.path.join(self.base_path, noise_file))
|
||||||
|
|
||||||
|
self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
|
||||||
|
self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self._convert_to_record(idx)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
@ -0,0 +1,356 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from paddle.io import Dataset
|
||||||
|
from pathos.multiprocessing import Pool
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..backends import load as load_audio
|
||||||
|
from ..utils import DATA_HOME
|
||||||
|
from ..utils import decompress
|
||||||
|
from ..utils.download import download_and_decompress
|
||||||
|
from .dataset import feat_funcs
|
||||||
|
|
||||||
|
__all__ = ['VoxCeleb']
|
||||||
|
|
||||||
|
|
||||||
|
class VoxCeleb(Dataset):
|
||||||
|
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
|
||||||
|
archieves_audio_dev = [
|
||||||
|
{
|
||||||
|
'url': source_url + 'vox1_dev_wav_partaa',
|
||||||
|
'md5': 'e395d020928bc15670b570a21695ed96',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'url': source_url + 'vox1_dev_wav_partab',
|
||||||
|
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'url': source_url + 'vox1_dev_wav_partac',
|
||||||
|
'md5': '017d579a2a96a077f40042ec33e51512',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'url': source_url + 'vox1_dev_wav_partad',
|
||||||
|
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
archieves_audio_test = [
|
||||||
|
{
|
||||||
|
'url': source_url + 'vox1_test_wav.zip',
|
||||||
|
'md5': '185fdc63c3c739954633d50379a3d102',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
archieves_meta = [
|
||||||
|
{
|
||||||
|
'url':
|
||||||
|
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
|
||||||
|
'md5':
|
||||||
|
'b73110731c9223c1461fe49cb48dddfc',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
|
||||||
|
sample_rate = 16000
|
||||||
|
meta_info = collections.namedtuple(
|
||||||
|
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
|
||||||
|
base_path = os.path.join(DATA_HOME, 'vox1')
|
||||||
|
wav_path = os.path.join(base_path, 'wav')
|
||||||
|
meta_path = os.path.join(base_path, 'meta')
|
||||||
|
veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
|
||||||
|
csv_path = os.path.join(base_path, 'csv')
|
||||||
|
subsets = ['train', 'dev', 'enroll', 'test']
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
subset: str='train',
|
||||||
|
feat_type: str='raw',
|
||||||
|
random_chunk: bool=True,
|
||||||
|
chunk_duration: float=3.0, # seconds
|
||||||
|
split_ratio: float=0.9, # train split ratio
|
||||||
|
seed: int=0,
|
||||||
|
target_dir: str=None,
|
||||||
|
vox2_base_path=None,
|
||||||
|
**kwargs):
|
||||||
|
"""VoxCeleb data prepare and get the specific dataset audio info
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
|
||||||
|
feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
|
||||||
|
random_chunk (bool, optional): random select a duration from audio. Defaults to True.
|
||||||
|
chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
|
||||||
|
target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
|
||||||
|
vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
|
||||||
|
"""
|
||||||
|
assert subset in self.subsets, \
|
||||||
|
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
|
||||||
|
|
||||||
|
self.subset = subset
|
||||||
|
self.spk_id2label = {}
|
||||||
|
self.feat_type = feat_type
|
||||||
|
self.feat_config = kwargs
|
||||||
|
self.random_chunk = random_chunk
|
||||||
|
self.chunk_duration = chunk_duration
|
||||||
|
self.split_ratio = split_ratio
|
||||||
|
self.target_dir = target_dir if target_dir else VoxCeleb.base_path
|
||||||
|
self.vox2_base_path = vox2_base_path
|
||||||
|
|
||||||
|
# if we set the target dir, we will change the vox data info data from base path to target dir
|
||||||
|
VoxCeleb.csv_path = os.path.join(
|
||||||
|
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
|
||||||
|
VoxCeleb.meta_path = os.path.join(
|
||||||
|
target_dir, "voxceleb",
|
||||||
|
'meta') if target_dir else VoxCeleb.meta_path
|
||||||
|
VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
|
||||||
|
'veri_test2.txt')
|
||||||
|
# self._data = self._get_data()[:1000] # KP: Small dataset test.
|
||||||
|
self._data = self._get_data()
|
||||||
|
super(VoxCeleb, self).__init__()
|
||||||
|
|
||||||
|
# Set up a seed to reproduce training or predicting result.
|
||||||
|
# random.seed(seed)
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
# Download audio files.
|
||||||
|
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
|
||||||
|
# so, we check the vox1/wav dir status
|
||||||
|
print(f"wav base path: {self.wav_path}")
|
||||||
|
if not os.path.isdir(self.wav_path):
|
||||||
|
print("start to download the voxceleb1 dataset")
|
||||||
|
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
|
||||||
|
self.archieves_audio_dev,
|
||||||
|
self.base_path,
|
||||||
|
decompress=False)
|
||||||
|
download_and_decompress( # download the vox1_test_wav.zip and unzip
|
||||||
|
self.archieves_audio_test,
|
||||||
|
self.base_path,
|
||||||
|
decompress=True)
|
||||||
|
|
||||||
|
# Download all parts and concatenate the files into one zip file.
|
||||||
|
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
|
||||||
|
print(f'Concatenating all parts to: {dev_zipfile}')
|
||||||
|
os.system(
|
||||||
|
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract all audio files of dev and test set.
|
||||||
|
decompress(dev_zipfile, self.base_path)
|
||||||
|
|
||||||
|
# Download meta files.
|
||||||
|
if not os.path.isdir(self.meta_path):
|
||||||
|
print("prepare the meta data")
|
||||||
|
download_and_decompress(
|
||||||
|
self.archieves_meta, self.meta_path, decompress=False)
|
||||||
|
|
||||||
|
# Data preparation.
|
||||||
|
if not os.path.isdir(self.csv_path):
|
||||||
|
os.makedirs(self.csv_path)
|
||||||
|
self.prepare_data()
|
||||||
|
|
||||||
|
data = []
|
||||||
|
print(
|
||||||
|
f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
|
||||||
|
)
|
||||||
|
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
|
||||||
|
for line in rf.readlines()[1:]:
|
||||||
|
audio_id, duration, wav, start, stop, spk_id = line.strip(
|
||||||
|
).split(',')
|
||||||
|
data.append(
|
||||||
|
self.meta_info(audio_id,
|
||||||
|
float(duration), wav,
|
||||||
|
int(start), int(stop), spk_id))
|
||||||
|
|
||||||
|
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
spk_id, label = line.strip().split(' ')
|
||||||
|
self.spk_id2label[spk_id] = int(label)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _convert_to_record(self, idx: int):
|
||||||
|
sample = self._data[idx]
|
||||||
|
|
||||||
|
record = {}
|
||||||
|
# To show all fields in a namedtuple: `type(sample)._fields`
|
||||||
|
for field in type(sample)._fields:
|
||||||
|
record[field] = getattr(sample, field)
|
||||||
|
|
||||||
|
waveform, sr = load_audio(record['wav'])
|
||||||
|
|
||||||
|
# random select a chunk audio samples from the audio
|
||||||
|
if self.random_chunk:
|
||||||
|
num_wav_samples = waveform.shape[0]
|
||||||
|
num_chunk_samples = int(self.chunk_duration * sr)
|
||||||
|
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
|
||||||
|
stop = start + num_chunk_samples
|
||||||
|
else:
|
||||||
|
start = record['start']
|
||||||
|
stop = record['stop']
|
||||||
|
|
||||||
|
waveform = waveform[start:stop]
|
||||||
|
|
||||||
|
assert self.feat_type in feat_funcs.keys(), \
|
||||||
|
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
|
||||||
|
feat_func = feat_funcs[self.feat_type]
|
||||||
|
feat = feat_func(
|
||||||
|
waveform, sr=sr, **self.feat_config) if feat_func else waveform
|
||||||
|
|
||||||
|
record.update({'feat': feat})
|
||||||
|
if self.subset in ['train',
|
||||||
|
'dev']: # Labels are available in train and dev.
|
||||||
|
record.update({'label': self.spk_id2label[record['spk_id']]})
|
||||||
|
|
||||||
|
return record
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_chunks(seg_dur, audio_id, audio_duration):
|
||||||
|
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
|
||||||
|
|
||||||
|
chunk_lst = [
|
||||||
|
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
|
||||||
|
for i in range(num_chunks)
|
||||||
|
]
|
||||||
|
return chunk_lst
|
||||||
|
|
||||||
|
def _get_audio_info(self, wav_file: str,
|
||||||
|
split_chunks: bool) -> List[List[str]]:
|
||||||
|
waveform, sr = load_audio(wav_file)
|
||||||
|
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
|
||||||
|
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
|
||||||
|
audio_duration = waveform.shape[0] / sr
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
if split_chunks: # Split into pieces of self.chunk_duration seconds.
|
||||||
|
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
|
||||||
|
audio_duration)
|
||||||
|
|
||||||
|
for chunk in uniq_chunks_list:
|
||||||
|
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
||||||
|
start_sample = int(float(s) * sr)
|
||||||
|
end_sample = int(float(e) * sr)
|
||||||
|
# id, duration, wav, start, stop, spk_id
|
||||||
|
ret.append([
|
||||||
|
chunk, audio_duration, wav_file, start_sample, end_sample,
|
||||||
|
spk_id
|
||||||
|
])
|
||||||
|
else: # Keep whole audio.
|
||||||
|
ret.append([
|
||||||
|
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
|
||||||
|
])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def generate_csv(self,
|
||||||
|
wav_files: List[str],
|
||||||
|
output_file: str,
|
||||||
|
split_chunks: bool=True):
|
||||||
|
print(f'Generating csv: {output_file}')
|
||||||
|
header = ["ID", "duration", "wav", "start", "stop", "spk_id"]
|
||||||
|
# Note: this may occurs c++ execption, but the program will execute fine
|
||||||
|
# so we can ignore the execption
|
||||||
|
with Pool(cpu_count()) as p:
|
||||||
|
infos = list(
|
||||||
|
tqdm(
|
||||||
|
p.imap(lambda x: self._get_audio_info(x, split_chunks),
|
||||||
|
wav_files),
|
||||||
|
total=len(wav_files)))
|
||||||
|
|
||||||
|
csv_lines = []
|
||||||
|
for info in infos:
|
||||||
|
csv_lines.extend(info)
|
||||||
|
|
||||||
|
with open(output_file, mode="w") as csv_f:
|
||||||
|
csv_writer = csv.writer(
|
||||||
|
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||||
|
csv_writer.writerow(header)
|
||||||
|
for line in csv_lines:
|
||||||
|
csv_writer.writerow(line)
|
||||||
|
|
||||||
|
def prepare_data(self):
|
||||||
|
# Audio of speakers in veri_test_file should not be included in training set.
|
||||||
|
print("start to prepare the data csv file")
|
||||||
|
enroll_files = set()
|
||||||
|
test_files = set()
|
||||||
|
# get the enroll and test audio file path
|
||||||
|
with open(self.veri_test_file, 'r') as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
_, enrol_file, test_file = line.strip().split(' ')
|
||||||
|
enroll_files.add(os.path.join(self.wav_path, enrol_file))
|
||||||
|
test_files.add(os.path.join(self.wav_path, test_file))
|
||||||
|
enroll_files = sorted(enroll_files)
|
||||||
|
test_files = sorted(test_files)
|
||||||
|
|
||||||
|
# get the enroll and test speakers
|
||||||
|
test_spks = set()
|
||||||
|
for file in (enroll_files + test_files):
|
||||||
|
spk = file.split('/wav/')[1].split('/')[0]
|
||||||
|
test_spks.add(spk)
|
||||||
|
|
||||||
|
# get all the train and dev audios file path
|
||||||
|
audio_files = []
|
||||||
|
speakers = set()
|
||||||
|
print("Getting file list...")
|
||||||
|
for path in [self.wav_path, self.vox2_base_path]:
|
||||||
|
# if vox2 directory is not set and vox2 is not a directory
|
||||||
|
# we will not process this directory
|
||||||
|
if not path or not os.path.exists(path):
|
||||||
|
print(f"{path} is an invalid path, please check again, "
|
||||||
|
"and we will ignore the vox2 base path")
|
||||||
|
continue
|
||||||
|
for file in glob.glob(
|
||||||
|
os.path.join(path, "**", "*.wav"), recursive=True):
|
||||||
|
spk = file.split('/wav/')[1].split('/')[0]
|
||||||
|
if spk in test_spks:
|
||||||
|
continue
|
||||||
|
speakers.add(spk)
|
||||||
|
audio_files.append(file)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
|
||||||
|
)
|
||||||
|
# encode the train and dev speakers label to spk_id2label.txt
|
||||||
|
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
|
||||||
|
for label, spk_id in enumerate(
|
||||||
|
sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
|
||||||
|
f.write(f'{spk_id} {label}\n')
|
||||||
|
|
||||||
|
audio_files = sorted(audio_files)
|
||||||
|
random.shuffle(audio_files)
|
||||||
|
split_idx = int(self.split_ratio * len(audio_files))
|
||||||
|
# split_ratio to train
|
||||||
|
train_files, dev_files = audio_files[:split_idx], audio_files[
|
||||||
|
split_idx:]
|
||||||
|
|
||||||
|
self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
|
||||||
|
self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
|
||||||
|
|
||||||
|
self.generate_csv(
|
||||||
|
enroll_files,
|
||||||
|
os.path.join(self.csv_path, 'enroll.csv'),
|
||||||
|
split_chunks=False)
|
||||||
|
self.generate_csv(
|
||||||
|
test_files,
|
||||||
|
os.path.join(self.csv_path, 'test.csv'),
|
||||||
|
split_chunks=False)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self._convert_to_record(idx)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from sklearn.metrics import roc_curve
|
||||||
|
|
||||||
|
|
||||||
|
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
|
||||||
|
"""Compute EER and return score threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
labels (np.ndarray): the trial label, shape: [N], one-dimention, N refer to the samples num
|
||||||
|
scores (np.ndarray): the trial scores, shape: [N], one-dimention, N refer to the samples num
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: eer and the specific threshold
|
||||||
|
"""
|
||||||
|
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
|
||||||
|
fnr = 1 - tpr
|
||||||
|
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
|
||||||
|
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
|
||||||
|
return eer, eer_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def compute_minDCF(positive_scores,
|
||||||
|
negative_scores,
|
||||||
|
c_miss=1.0,
|
||||||
|
c_fa=1.0,
|
||||||
|
p_target=0.01):
|
||||||
|
"""
|
||||||
|
This is modified from SpeechBrain
|
||||||
|
https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509
|
||||||
|
Computes the minDCF metric normally used to evaluate speaker verification
|
||||||
|
systems. The min_DCF is the minimum of the following C_det function computed
|
||||||
|
within the defined threshold range:
|
||||||
|
|
||||||
|
C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 -p_target)
|
||||||
|
|
||||||
|
where p_miss is the missing probability and p_fa is the probability of having
|
||||||
|
a false alarm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positive_scores (Paddle.Tensor): The scores from entries of the same class.
|
||||||
|
negative_scores (Paddle.Tensor): The scores from entries of different classes.
|
||||||
|
c_miss (float, optional): Cost assigned to a missing error (default 1.0).
|
||||||
|
c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
|
||||||
|
p_target (float, optional): Prior probability of having a target (default 0.01).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: min dcf and the specific threshold
|
||||||
|
"""
|
||||||
|
# Computing candidate thresholds
|
||||||
|
if len(positive_scores.shape) > 1:
|
||||||
|
positive_scores = positive_scores.squeeze()
|
||||||
|
|
||||||
|
if len(negative_scores.shape) > 1:
|
||||||
|
negative_scores = negative_scores.squeeze()
|
||||||
|
|
||||||
|
thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores]))
|
||||||
|
thresholds = paddle.unique(thresholds)
|
||||||
|
|
||||||
|
# Adding intermediate thresholds
|
||||||
|
interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2
|
||||||
|
thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds]))
|
||||||
|
|
||||||
|
# Computing False Rejection Rate (miss detection)
|
||||||
|
positive_scores = paddle.concat(
|
||||||
|
len(thresholds) * [positive_scores.unsqueeze(0)])
|
||||||
|
pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds
|
||||||
|
p_miss = (pos_scores_threshold.sum(0)
|
||||||
|
).astype("float32") / positive_scores.shape[1]
|
||||||
|
del positive_scores
|
||||||
|
del pos_scores_threshold
|
||||||
|
|
||||||
|
# Computing False Acceptance Rate (false alarm)
|
||||||
|
negative_scores = paddle.concat(
|
||||||
|
len(thresholds) * [negative_scores.unsqueeze(0)])
|
||||||
|
neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds
|
||||||
|
p_fa = (neg_scores_threshold.sum(0)
|
||||||
|
).astype("float32") / negative_scores.shape[1]
|
||||||
|
del negative_scores
|
||||||
|
del neg_scores_threshold
|
||||||
|
|
||||||
|
c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
|
||||||
|
c_min = paddle.min(c_det, axis=0)
|
||||||
|
min_index = paddle.argmin(c_det, axis=0)
|
||||||
|
return float(c_min), float(thresholds[min_index])
|
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .infer import VectorExecutor
|
@ -0,0 +1,139 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.s2t.modules.initializer import KaimingUniform
|
||||||
|
"""
|
||||||
|
To align the initializer between paddle and torch,
|
||||||
|
the API below are set defalut initializer with priority higger than global initializer.
|
||||||
|
"""
|
||||||
|
global_init_type = None
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
def __init__(self,
|
||||||
|
normalized_shape,
|
||||||
|
epsilon=1e-05,
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
name=None):
|
||||||
|
if weight_attr is None:
|
||||||
|
weight_attr = paddle.ParamAttr(
|
||||||
|
initializer=nn.initializer.Constant(1.0))
|
||||||
|
if bias_attr is None:
|
||||||
|
bias_attr = paddle.ParamAttr(
|
||||||
|
initializer=nn.initializer.Constant(0.0))
|
||||||
|
super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr,
|
||||||
|
bias_attr, name)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNorm1D(nn.BatchNorm1D):
|
||||||
|
def __init__(self,
|
||||||
|
num_features,
|
||||||
|
momentum=0.9,
|
||||||
|
epsilon=1e-05,
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
data_format='NCL',
|
||||||
|
name=None):
|
||||||
|
if weight_attr is None:
|
||||||
|
weight_attr = paddle.ParamAttr(
|
||||||
|
initializer=nn.initializer.Constant(1.0))
|
||||||
|
if bias_attr is None:
|
||||||
|
bias_attr = paddle.ParamAttr(
|
||||||
|
initializer=nn.initializer.Constant(0.0))
|
||||||
|
super(BatchNorm1D,
|
||||||
|
self).__init__(num_features, momentum, epsilon, weight_attr,
|
||||||
|
bias_attr, data_format, name)
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Embedding):
|
||||||
|
def __init__(self,
|
||||||
|
num_embeddings,
|
||||||
|
embedding_dim,
|
||||||
|
padding_idx=None,
|
||||||
|
sparse=False,
|
||||||
|
weight_attr=None,
|
||||||
|
name=None):
|
||||||
|
if weight_attr is None:
|
||||||
|
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal())
|
||||||
|
super(Embedding, self).__init__(num_embeddings, embedding_dim,
|
||||||
|
padding_idx, sparse, weight_attr, name)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
def __init__(self,
|
||||||
|
in_features,
|
||||||
|
out_features,
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
name=None):
|
||||||
|
if weight_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
if bias_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
super(Linear, self).__init__(in_features, out_features, weight_attr,
|
||||||
|
bias_attr, name)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1D(nn.Conv1D):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
padding_mode='zeros',
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
data_format='NCL'):
|
||||||
|
if weight_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
print("set kaiming_uniform")
|
||||||
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
if bias_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
super(Conv1D, self).__init__(
|
||||||
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2D(nn.Conv2D):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
padding_mode='zeros',
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
data_format='NCHW'):
|
||||||
|
if weight_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
if bias_attr is None:
|
||||||
|
if global_init_type == "kaiming_uniform":
|
||||||
|
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
|
||||||
|
super(Conv2D, self).__init__(
|
||||||
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
|
groups, padding_mode, weight_attr, bias_attr, data_format)
|
@ -0,0 +1,172 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
from paddle.fluid import framework
|
||||||
|
from paddle.fluid import unique_name
|
||||||
|
from paddle.fluid.core import VarDesc
|
||||||
|
from paddle.fluid.initializer import MSRAInitializer
|
||||||
|
|
||||||
|
__all__ = ['KaimingUniform']
|
||||||
|
|
||||||
|
|
||||||
|
class KaimingUniform(MSRAInitializer):
    r"""Implements the Kaiming Uniform initializer

    This class implements the weight initialization from the paper
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
    by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
    robust initialization method that particularly considers the rectifier
    nonlinearities.

    In case of Uniform distribution, the range is [-x, x], where

    .. math::

        x = \sqrt{\frac{1.0}{fan\_in}}

    In case of Normal distribution, the mean is 0 and the standard deviation
    is

    .. math::

        \sqrt{\frac{2.0}{fan\_in}}

    Args:
        fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\
            inferred from the variable. default is None.

    Note:
        It is recommended to set fan_in to None for most cases.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            linear = nn.Linear(2,
                               4,
                               weight_attr=nn.initializer.KaimingUniform())
            data = paddle.rand([30, 10, 2], dtype='float32')
            res = linear(data)

    """

    def __init__(self, fan_in=None):
        # uniform=True selects the Kaiming *uniform* variant of the MSRA
        # scheme; seed=0 means "defer to the program-level random seed"
        # (picked up later in __call__).
        super(KaimingUniform, self).__init__(
            uniform=True, fan_in=fan_in, seed=0)

    def __call__(self, var, block=None):
        """Initialize the input tensor with MSRA initialization.

        Appends the random-initialization op(s) to ``block`` so that ``var``
        is filled when the program runs (static graph), or runs them eagerly
        in dygraph mode.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        # Resolve the target block (handles the block=None dygraph case).
        block = self._check_block(block)

        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
        # fan-out is computed but unused here; only fan-in drives the scale.
        f_in, f_out = self._compute_fans(var)

        # If fan_in is passed, use it
        fan_in = f_in if self._fan_in is None else self._fan_in

        # seed==0 means "unset": fall back to the program-wide seed.
        # NOTE(review): this mutates self._seed, so the first program seed
        # seen is reused by later calls to the same initializer instance —
        # matches the upstream Paddle MSRAInitializer; confirm if reusing
        # one instance across programs.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # to be compatible with fp16 initializers: random ops may not support
        # fp16/bf16 output, so generate in fp32 into a temp var and cast back
        # into `var` afterwards (see the "cast" op below).
        if var.dtype == VarDesc.VarType.FP16 or (
                var.dtype == VarDesc.VarType.BF16 and not self._uniform):
            out_dtype = VarDesc.VarType.FP32
            # NOTE(review): 'masra_init' looks like a typo for 'msra_init';
            # it only affects the generated temp-variable name, not behavior.
            out_var = block.create_var(
                name=unique_name.generate(
                    ".".join(['masra_init', var.name, 'tmp'])),
                shape=var.shape,
                dtype=out_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            # No dtype workaround needed: write directly into `var`.
            out_dtype = var.dtype
            out_var = var

        if self._uniform:
            # Kaiming-uniform bound: U(-limit, limit), limit = sqrt(1/fan_in).
            limit = np.sqrt(1.0 / float(fan_in))
            op = block.append_op(
                type="uniform_random",
                inputs={},
                outputs={"Out": out_var},
                attrs={
                    "shape": out_var.shape,
                    "dtype": int(out_dtype),
                    "min": -limit,
                    "max": limit,
                    "seed": self._seed
                },
                stop_gradient=True)

        else:
            # Kaiming-normal: N(0, std), std = sqrt(2/fan_in).
            op = block.append_op(
                type="gaussian_random",
                outputs={"Out": out_var},
                attrs={
                    "shape": out_var.shape,
                    "dtype": int(out_dtype),
                    "mean": 0.0,
                    "std": std,
                    "seed": self._seed
                },
                stop_gradient=True)

        # Second half of the fp16/bf16 workaround: cast the fp32 temp result
        # back into the originally-requested dtype of `var`.
        if var.dtype == VarDesc.VarType.FP16 or (
                var.dtype == VarDesc.VarType.BF16 and not self._uniform):
            block.append_op(
                type="cast",
                inputs={"X": out_var},
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})

        # In static-graph mode, record the producing op on the variable so
        # the framework can track the initialization dependency.
        if not framework.in_dygraph_mode():
            var.op = op
        return op
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultInitializerContext(object):
    """Context manager that installs a process-wide default initializer type.

    On entry it sets ``align.global_init_type`` to ``init_type`` (unless
    ``init_type`` is None, in which case it does nothing); on exit it always
    resets ``align.global_init_type`` back to None.

    egs:
        with DefaultInitializerContext("kaiming_uniform"):
            code for setup_model
    """

    def __init__(self, init_type=None):
        # Name of the initializer to make the global default; None disables
        # the override on entry.
        self.init_type = init_type

    def __enter__(self):
        # Guard clause: a None init_type means "leave the global untouched".
        if self.init_type is not None:
            # Imported lazily to avoid a circular import at module load time.
            from paddlespeech.s2t.modules import align
            align.global_init_type = self.init_type

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always clear the global on exit, even when __enter__ was a no-op.
        from paddlespeech.s2t.modules import align
        align.global_init_type = None
|
# NOTE(review): the following web-viewer residue was captured with the file
# and is not part of the source:
#   "Some files were not shown because too many files have changed in this diff Show More"
#   "Loading…" / "Reference in new issue"