[tts] add tts finetune example (#2297)

* add tts finetune example, test=tts

* fix finetune

Co-authored-by: TianYuan <white-sky@qq.com>
pull/2304/head
liangym 2 years ago committed by GitHub
parent 043b21d3b4
commit 1f100b1573
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,214 @@
# Finetune your own AM based on FastSpeech2 with AISHELL-3.
This example shows how to finetune your own AM based on FastSpeech2 with AISHELL-3. We use part of csmsc's data (top 200) as finetune data in this example. The example is implemented according to this [discussion](https://github.com/PaddlePaddle/PaddleSpeech/discussions/1842). Thanks to the developer for the idea.
We use AISHELL-3 to train a multi-speaker fastspeech2 model here. You can refer [examples/aishell3/tts3](https://github.com/lym0302/PaddleSpeech/tree/develop/examples/aishell3/tts3) to train multi-speaker fastspeech2 from scratch.
## Prepare
### Download Pretrained Fastspeech2 model
Assume the path to the model is `./pretrained_models`. Download pretrained fastspeech2 model with aishell3: [fastspeech2_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip).
```bash
mkdir -p pretrained_models && cd pretrained_models
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip
unzip fastspeech2_aishell3_ckpt_1.1.0.zip
cd ../
```
### Download MFA tools and pretrained model
Assume the path to the MFA tool is `./tools`. Download [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz) and pretrained MFA models with aishell3: [aishell3_model.zip](https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/aishell3_model.zip).
```bash
mkdir -p tools && cd tools
# mfa tool
wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz
tar xvf montreal-forced-aligner_linux.tar.gz
cp montreal-forced-aligner/lib/libpython3.6m.so.1.0 montreal-forced-aligner/lib/libpython3.6m.so
# pretrained mfa model
mkdir -p aligner && cd aligner
wget https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/aishell3_model.zip
unzip aishell3_model.zip
wget https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/simple.lexicon
cd ../../
```
### Prepare your data
Assume the path to the dataset is `./input`. This directory contains audio files (*.wav) and label file (labels.txt). The audio file is in wav format. The format of the label file is: utt_id|pinyin. Here is an example of the first 200 data of csmsc.
```bash
mkdir -p input && cd input
wget https://paddlespeech.bj.bcebos.com/datasets/csmsc_mini.zip
unzip csmsc_mini.zip
cd ../
```
When "Prepare" done. The structure of the current directory is listed below.
```text
├── input
│ ├── csmsc_mini
│ │ ├── 000001.wav
│ │ ├── 000002.wav
│ │ ├── 000003.wav
│ │ ├── ...
│ │ ├── 000200.wav
│ │ ├── labels.txt
│ └── csmsc_mini.zip
├── pretrained_models
│ ├── fastspeech2_aishell3_ckpt_1.1.0
│ │ ├── default.yaml
│ │ ├── energy_stats.npy
│ │ ├── phone_id_map.txt
│ │ ├── pitch_stats.npy
│ │ ├── snapshot_iter_96400.pdz
│ │ ├── speaker_id_map.txt
│ │ └── speech_stats.npy
│ └── fastspeech2_aishell3_ckpt_1.1.0.zip
└── tools
├── aligner
│ ├── aishell3_model
│ ├── aishell3_model.zip
│ └── simple.lexicon
├── montreal-forced-aligner
│ ├── bin
│ ├── lib
│ └── pretrained_models
└── montreal-forced-aligner_linux.tar.gz
...
```
## Get Started
Run the command below to
1. **source path**.
2. finetune the model.
3. synthesize wavs.
- synthesize waveform from text file.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to run only one stage.
### Model Finetune
Finetune a FastSpeech2 model.
```bash
./run.sh --stage 0 --stop-stage 0
```
`stage 0` of `run.sh` calls `finetune.py`, here's the complete help message.
```text
usage: finetune.py [-h] [--input_dir INPUT_DIR] [--pretrained_model_dir PRETRAINED_MODEL_DIR]
[--mfa_dir MFA_DIR] [--dump_dir DUMP_DIR]
[--output_dir OUTPUT_DIR] [--lang LANG]
[--ngpu NGPU]
optional arguments:
-h, --help show this help message and exit
--input_dir INPUT_DIR
directory containing audio and label file
--pretrained_model_dir PRETRAINED_MODEL_DIR
Path to pretrained model
--mfa_dir MFA_DIR directory to save aligned files
--dump_dir DUMP_DIR
directory to save feature files and metadata
--output_dir OUTPUT_DIR
directory to save finetune model
--lang LANG Choose input audio language, zh or en
--ngpu NGPU if ngpu=0, use cpu
--epoch EPOCH the epoch of finetune
--batch_size BATCH_SIZE
the batch size of finetune, default -1 means same as pretrained model
```
1. `--input_dir` is the directory containing audio and label file.
2. `--pretrained_model_dir` is the directory incluing pretrained fastspeech2_aishell3 model.
3. `--mfa_dir` is the directory to save the results of aligning from pretrained MFA_aishell3 model.
4. `--dump_dir` is the directory including audio feature and metadata.
5. `--output_dir` is the directory to save finetune model.
6. `--lang` is the language of input audio, zh or en.
7. `--ngpu` is the number of gpu.
8. `--epoch` is the epoch of finetune.
9. `--batch_size` is the batch size of finetune.
### Synthesizing
We use [HiFiGAN](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5) as the neural vocoder.
Assume the path to the hifigan model is `./pretrained_models`. Download the pretrained HiFiGAN model from [hifigan_aishell3_ckpt_0.2.0](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip) and unzip it.
```bash
cd pretrained_models
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip
unzip hifigan_aishell3_ckpt_0.2.0.zip
cd ../
```
HiFiGAN checkpoint contains files listed below.
```text
hifigan_aishell3_ckpt_0.2.0
├── default.yaml # default config used to train HiFiGAN
├── feats_stats.npy # statistics used to normalize spectrogram when training HiFiGAN
└── snapshot_iter_2500000.pdz # generator parameters of HiFiGAN
```
Modify `ckpt` in `run.sh` to the final model in `exp/default/checkpoints`.
```bash
./run.sh --stage 1 --stop-stage 1
```
`stage 1` of `run.sh` calls `${BIN_DIR}/../synthesize_e2e.py`, which can synthesize waveform from text file.
```text
usage: synthesize_e2e.py [-h]
[--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU]
[--text TEXT] [--output_dir OUTPUT_DIR]
Synthesize with acoustic model & vocoder
optional arguments:
-h, --help show this help message and exit
--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
Choose acoustic model type of tts task.
--am_config AM_CONFIG
Config of acoustic model.
--am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model.
--phones_dict PHONES_DICT
phone vocabulary file.
--tones_dict TONES_DICT
tone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
Choose vocoder type of tts task.
--voc_config VOC_CONFIG
Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc.
--lang LANG Choose model language. zh or en
--inference_dir INFERENCE_DIR
dir to save inference models
--ngpu NGPU if ngpu == 0, use cpu.
--text TEXT text to synthesize, a 'utt_id sentence' pair per line.
--output_dir OUTPUT_DIR
output dir.
```
1. `--am` is acoustic model type with the format {model_name}_{dataset}
2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model.
3. `--voc` is vocoder type with the format {model_name}_{dataset}
4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`.
6. `--text` is the text file, which contains sentences to synthesize.
7. `--output_dir` is the directory to save synthesized audio files.
8. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
### Tips
If you want to get better audio quality, you can use more audios to finetune.

@ -0,0 +1,192 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
from typing import Union
import yaml
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.t2s.exps.fastspeech2.train import train_sp
from local.check_oov import get_check_result
from local.extract import extract_feature
from local.label_process import get_single_label
from local.prepare_env import generate_finetune_env
from utils.gen_duration_from_textgrid import gen_duration_from_textgrid
DICT_EN = 'tools/aligner/cmudict-0.7b'
DICT_ZH = 'tools/aligner/simple.lexicon'
MODEL_DIR_EN = 'tools/aligner/vctk_model.zip'
MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
MFA_PHONE_EN = 'tools/aligner/vctk_model/meta.yaml'
MFA_PHONE_ZH = 'tools/aligner/aishell3_model/meta.yaml'
MFA_PATH = 'tools/montreal-forced-aligner/bin'
os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
class TrainArgs():
def __init__(self, ngpu, config_file, dump_dir: Path, output_dir: Path):
self.config = str(config_file)
self.train_metadata = str(dump_dir / "train/norm/metadata.jsonl")
self.dev_metadata = str(dump_dir / "dev/norm/metadata.jsonl")
self.output_dir = str(output_dir)
self.ngpu = ngpu
self.phones_dict = str(dump_dir / "phone_id_map.txt")
self.speaker_dict = str(dump_dir / "speaker_id_map.txt")
self.voice_cloning = False
def get_mfa_result(
input_dir: Union[str, Path],
mfa_dir: Union[str, Path],
lang: str='en', ):
"""get mfa result
Args:
input_dir (Union[str, Path]): input dir including wav file and label
mfa_dir (Union[str, Path]): mfa result dir
lang (str, optional): input audio language. Defaults to 'en'.
"""
# MFA
if lang == 'en':
DICT = DICT_EN
MODEL_DIR = MODEL_DIR_EN
elif lang == 'zh':
DICT = DICT_ZH
MODEL_DIR = MODEL_DIR_ZH
else:
print('please input right lang!!')
CMD = 'mfa_align' + ' ' + str(
input_dir) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(mfa_dir)
os.system(CMD)
if __name__ == '__main__':
# parse config and args
parser = argparse.ArgumentParser(
description="Preprocess audio and then extract features.")
parser.add_argument(
"--input_dir",
type=str,
default="./input/baker_mini",
help="directory containing audio and label file")
parser.add_argument(
"--pretrained_model_dir",
type=str,
default="./pretrained_models/fastspeech2_aishell3_ckpt_1.1.0",
help="Path to pretrained model")
parser.add_argument(
"--mfa_dir",
type=str,
default="./mfa_result",
help="directory to save aligned files")
parser.add_argument(
"--dump_dir",
type=str,
default="./dump",
help="directory to save feature files and metadata.")
parser.add_argument(
"--output_dir",
type=str,
default="./exp/default/",
help="directory to save finetune model.")
parser.add_argument(
'--lang',
type=str,
default='zh',
choices=['zh', 'en'],
help='Choose input audio language. zh or en')
parser.add_argument(
"--ngpu", type=int, default=2, help="if ngpu=0, use cpu.")
parser.add_argument("--epoch", type=int, default=100, help="finetune epoch")
parser.add_argument(
"--batch_size",
type=int,
default=-1,
help="batch size, default -1 means same as pretrained model")
args = parser.parse_args()
fs = 24000
n_shift = 300
input_dir = Path(args.input_dir).expanduser()
mfa_dir = Path(args.mfa_dir).expanduser()
mfa_dir.mkdir(parents=True, exist_ok=True)
dump_dir = Path(args.dump_dir).expanduser()
dump_dir.mkdir(parents=True, exist_ok=True)
output_dir = Path(args.output_dir).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
pretrained_model_dir = Path(args.pretrained_model_dir).expanduser()
# read config
config_file = pretrained_model_dir / "default.yaml"
with open(config_file) as f:
config = CfgNode(yaml.safe_load(f))
config.max_epoch = config.max_epoch + args.epoch
if args.batch_size > 0:
config.batch_size = args.batch_size
if args.lang == 'en':
lexicon_file = DICT_EN
mfa_phone_file = MFA_PHONE_EN
elif args.lang == 'zh':
lexicon_file = DICT_ZH
mfa_phone_file = MFA_PHONE_ZH
else:
print('please input right lang!!')
am_phone_file = pretrained_model_dir / "phone_id_map.txt"
label_file = input_dir / "labels.txt"
#check phone for mfa and am finetune
oov_words, oov_files, oov_file_words = get_check_result(
label_file, lexicon_file, mfa_phone_file, am_phone_file)
input_dir = get_single_label(label_file, oov_files, input_dir)
# get mfa result
get_mfa_result(input_dir, mfa_dir, args.lang)
# # generate durations.txt
duration_file = "./durations.txt"
gen_duration_from_textgrid(mfa_dir, duration_file, fs, n_shift)
# generate phone and speaker map files
extract_feature(duration_file, config, input_dir, dump_dir,
pretrained_model_dir)
# create finetune env
generate_finetune_env(output_dir, pretrained_model_dir)
# create a new args for training
train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir)
# finetune models
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (train_args, config), nprocs=args.ngpu)
else:
train_sp(train_args, config)

@ -0,0 +1,125 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
def check_phone(label_file: Union[str, Path],
pinyin_phones: Dict[str, str],
mfa_phones: List[str],
am_phones: List[str],
oov_record: str="./oov_info.txt"):
"""Check whether the phoneme corresponding to the audio text content
is in the phoneme list of the pretrained mfa model to ensure that the alignment is normal.
Check whether the phoneme corresponding to the audio text content
is in the phoneme list of the pretrained am model to ensure finetune (normalize) is normal.
Args:
label_file (Union[str, Path]): label file, format: utt_id|phone seq
pinyin_phones (dict): pinyin to phones map dict
mfa_phones (list): the phone list of pretrained mfa model
am_phones (list): the phone list of pretrained mfa model
Returns:
oov_words (list): oov words
oov_files (list): utt id list that exist oov
oov_file_words (dict): the oov file and oov phone in this file
"""
oov_words = []
oov_files = []
oov_file_words = {}
with open(label_file, "r") as f:
for line in f.readlines():
utt_id = line.split("|")[0]
transcription = line.strip().split("|")[1]
flag = 0
temp_oov_words = []
for word in transcription.split(" "):
if word not in pinyin_phones.keys():
temp_oov_words.append(word)
flag = 1
if word not in oov_words:
oov_words.append(word)
else:
for p in pinyin_phones[word]:
if p not in mfa_phones or p not in am_phones:
temp_oov_words.append(word)
flag = 1
if word not in oov_words:
oov_words.append(word)
if flag == 1:
oov_files.append(utt_id)
oov_file_words[utt_id] = temp_oov_words
if oov_record is not None:
with open(oov_record, "w") as fw:
fw.write("oov_words: " + str(oov_words) + "\n")
fw.write("oov_files: " + str(oov_files) + "\n")
fw.write("oov_file_words: " + str(oov_file_words) + "\n")
return oov_words, oov_files, oov_file_words
def get_pinyin_phones(lexicon_file: Union[str, Path]):
# pinyin to phones
pinyin_phones = {}
with open(lexicon_file, "r") as f2:
for line in f2.readlines():
line_list = line.strip().split(" ")
pinyin = line_list[0]
if line_list[1] == '':
phones = line_list[2:]
else:
phones = line_list[1:]
pinyin_phones[pinyin] = phones
return pinyin_phones
def get_mfa_phone(mfa_phone_file: Union[str, Path]):
# get phones from pretrained mfa model (meta.yaml)
mfa_phones = []
with open(mfa_phone_file, "r") as f:
for line in f.readlines():
if line.startswith("-"):
phone = line.strip().split(" ")[-1]
mfa_phones.append(phone)
return mfa_phones
def get_am_phone(am_phone_file: Union[str, Path]):
# get phones from pretrained am model (phone_id_map.txt)
am_phones = []
with open(am_phone_file, "r") as f:
for line in f.readlines():
phone = line.strip().split(" ")[0]
am_phones.append(phone)
return am_phones
def get_check_result(label_file: Union[str, Path],
lexicon_file: Union[str, Path],
mfa_phone_file: Union[str, Path],
am_phone_file: Union[str, Path]):
pinyin_phones = get_pinyin_phones(lexicon_file)
mfa_phones = get_mfa_phone(mfa_phone_file)
am_phones = get_am_phone(am_phone_file)
oov_words, oov_files, oov_file_words = check_phone(
label_file, pinyin_phones, mfa_phones, am_phones)
return oov_words, oov_files, oov_file_words

@ -0,0 +1,287 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from operator import itemgetter
from pathlib import Path
from typing import Dict
from typing import Union
import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.get_feats import Energy
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.get_feats import Pitch
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.exps.fastspeech2.preprocess import process_sentences
def read_stats(stats_file: Union[str, Path]):
scaler = StandardScaler()
scaler.mean_ = np.load(stats_file)[0]
scaler.scale_ = np.load(stats_file)[1]
scaler.n_features_in_ = scaler.mean_.shape[0]
return scaler
def get_stats(pretrained_model_dir: Path):
speech_stats_file = pretrained_model_dir / "speech_stats.npy"
pitch_stats_file = pretrained_model_dir / "pitch_stats.npy"
energy_stats_file = pretrained_model_dir / "energy_stats.npy"
speech_scaler = read_stats(speech_stats_file)
pitch_scaler = read_stats(pitch_stats_file)
energy_scaler = read_stats(energy_stats_file)
return speech_scaler, pitch_scaler, energy_scaler
def get_map(duration_file: Union[str, Path],
dump_dir: Path,
pretrained_model_dir: Path):
"""get phone map and speaker map, save on dump_dir
Args:
duration_file (str): durantions.txt
dump_dir (Path): dump dir
pretrained_model_dir (Path): pretrained model dir
"""
# copy phone map file from pretrained model path
phones_dict = dump_dir / "phone_id_map.txt"
os.system("cp %s %s" %
(pretrained_model_dir / "phone_id_map.txt", phones_dict))
# create a new speaker map file, replace the previous speakers.
sentences, speaker_set = get_phn_dur(duration_file)
merge_silence(sentences)
speakers = sorted(list(speaker_set))
num = len(speakers)
speaker_dict = dump_dir / "speaker_id_map.txt"
with open(speaker_dict, 'w') as f, open(pretrained_model_dir /
"speaker_id_map.txt", 'r') as fr:
for i, spk in enumerate(speakers):
f.write(spk + ' ' + str(i) + '\n')
for line in fr.readlines():
spk_id = line.strip().split(" ")[-1]
if int(spk_id) >= num:
f.write(line)
vocab_phones = {}
with open(phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
vocab_phones[phn] = int(id)
vocab_speaker = {}
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
for spk, id in spk_id:
vocab_speaker[spk] = int(id)
return sentences, vocab_phones, vocab_speaker
def get_extractor(config):
# Extractor
mel_extractor = LogMelFBank(
sr=config.fs,
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window,
n_mels=config.n_mels,
fmin=config.fmin,
fmax=config.fmax)
pitch_extractor = Pitch(
sr=config.fs,
hop_length=config.n_shift,
f0min=config.f0min,
f0max=config.f0max)
energy_extractor = Energy(
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window)
return mel_extractor, pitch_extractor, energy_extractor
def normalize(speech_scaler,
pitch_scaler,
energy_scaler,
vocab_phones: Dict,
vocab_speaker: Dict,
raw_dump_dir: Path,
type: str):
dumpdir = raw_dump_dir / type / "norm"
dumpdir = Path(dumpdir).expanduser()
dumpdir.mkdir(parents=True, exist_ok=True)
# get dataset
metadata_file = raw_dump_dir / type / "raw" / "metadata.jsonl"
with jsonlines.open(metadata_file, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata,
converters={
"speech": np.load,
"pitch": np.load,
"energy": np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
# process each file
output_metadata = []
for item in tqdm(dataset):
utt_id = item['utt_id']
speech = item['speech']
pitch = item['pitch']
energy = item['energy']
# normalize
speech = speech_scaler.transform(speech)
speech_dir = dumpdir / "data_speech"
speech_dir.mkdir(parents=True, exist_ok=True)
speech_path = speech_dir / f"{utt_id}_speech.npy"
np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
pitch = pitch_scaler.transform(pitch)
pitch_dir = dumpdir / "data_pitch"
pitch_dir.mkdir(parents=True, exist_ok=True)
pitch_path = pitch_dir / f"{utt_id}_pitch.npy"
np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False)
energy = energy_scaler.transform(energy)
energy_dir = dumpdir / "data_energy"
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / f"{utt_id}_energy.npy"
np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
phone_ids = [vocab_phones[p] for p in item['phones']]
spk_id = vocab_speaker[item["speaker"]]
record = {
"utt_id": item['utt_id'],
"spk_id": spk_id,
"text": phone_ids,
"text_lengths": item['text_lengths'],
"speech_lengths": item['speech_lengths'],
"durations": item['durations'],
"speech": str(speech_path),
"pitch": str(pitch_path),
"energy": str(energy_path)
}
# add spk_emb for voice cloning
if "spk_emb" in item:
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
writer.write(item)
logging.info(f"metadata dumped into {output_metadata_path}")
def extract_feature(duration_file: str,
config,
input_dir: Path,
dump_dir: Path,
pretrained_model_dir: Path):
sentences, vocab_phones, vocab_speaker = get_map(duration_file, dump_dir,
pretrained_model_dir)
mel_extractor, pitch_extractor, energy_extractor = get_extractor(config)
wav_files = sorted(list((input_dir).rglob("*.wav")))
# split data into 3 sections, train: 80%, dev: 10%, test: 10%
num_train = math.ceil(len(wav_files) * 0.8)
num_dev = math.ceil(len(wav_files) * 0.1)
print(num_train, num_dev)
train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]
train_dump_dir = dump_dir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True)
dev_dump_dir = dump_dir / "dev" / "raw"
dev_dump_dir.mkdir(parents=True, exist_ok=True)
test_dump_dir = dump_dir / "test" / "raw"
test_dump_dir.mkdir(parents=True, exist_ok=True)
# process for the 3 sections
num_cpu = 4
cut_sil = True
spk_emb_dir = None
write_metadata_method = "w"
speech_scaler, pitch_scaler, energy_scaler = get_stats(pretrained_model_dir)
if train_wav_files:
process_sentences(
config=config,
fps=train_wav_files,
sentences=sentences,
output_dir=train_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=num_cpu,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=write_metadata_method)
# norm
normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones,
vocab_speaker, dump_dir, "train")
if dev_wav_files:
process_sentences(
config=config,
fps=dev_wav_files,
sentences=sentences,
output_dir=dev_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=num_cpu,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=write_metadata_method)
# norm
normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones,
vocab_speaker, dump_dir, "dev")
if test_wav_files:
process_sentences(
config=config,
fps=test_wav_files,
sentences=sentences,
output_dir=test_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=num_cpu,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=write_metadata_method)
# norm
normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones,
vocab_speaker, dump_dir, "test")

@ -0,0 +1,63 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import List
from typing import Union
def change_baker_label(baker_label_file: Union[str, Path],
out_label_file: Union[str, Path]):
"""change baker label file to regular label file
Args:
baker_label_file (Union[str, Path]): Original baker label file
out_label_file (Union[str, Path]): regular label file
"""
with open(baker_label_file) as f:
lines = f.readlines()
with open(out_label_file, "w") as fw:
for i in range(0, len(lines), 2):
utt_id = lines[i].split()[0]
transcription = lines[i + 1].strip()
fw.write(utt_id + "|" + transcription + "\n")
def get_single_label(label_file: Union[str, Path],
oov_files: List[Union[str, Path]],
input_dir: Union[str, Path]):
"""Divide the label file into individual files according to label_file
Args:
label_file (str or Path): label file, format: utt_id|phones id
input_dir (Path): input dir including audios
"""
input_dir = Path(input_dir).expanduser()
new_dir = input_dir / "newdir"
new_dir.mkdir(parents=True, exist_ok=True)
with open(label_file, "r") as f:
for line in f.readlines():
utt_id = line.split("|")[0]
if utt_id not in oov_files:
transcription = line.split("|")[1].strip()
wav_file = str(input_dir) + "/" + utt_id + ".wav"
new_wav_file = str(new_dir) + "/" + utt_id + ".wav"
os.system("cp %s %s" % (wav_file, new_wav_file))
single_file = str(new_dir) + "/" + utt_id + ".txt"
with open(single_file, "w") as fw:
fw.write(transcription)
return new_dir

@ -0,0 +1,35 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
def generate_finetune_env(output_dir: Path, pretrained_model_dir: Path):
output_dir = output_dir / "checkpoints/"
output_dir = output_dir.resolve()
output_dir.mkdir(parents=True, exist_ok=True)
model_path = sorted(list((pretrained_model_dir).rglob("*.pdz")))[0]
model_path = model_path.resolve()
iter = int(str(model_path).split("_")[-1].split(".")[0])
model_file = str(model_path).split("/")[-1]
os.system("cp %s %s" % (model_path, output_dir))
records_file = output_dir / "records.jsonl"
with open(records_file, "w") as f:
line = "\"time\": \"2022-08-06 07:51:53.463650\", \"path\": \"%s\", \"iteration\": %d" % (
str(output_dir / model_file), iter)
f.write("{" + line + "}" + "\n")

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=fastspeech2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,61 @@
#!/bin/bash
set -e
source path.sh
input_dir=./input/csmsc_mini
pretrained_model_dir=./pretrained_models/fastspeech2_aishell3_ckpt_1.1.0
mfa_dir=./mfa_result
dump_dir=./dump
output_dir=./exp/default
lang=zh
ngpu=2
ckpt=snapshot_iter_96600
gpus=0,1
CUDA_VISIBLE_DEVICES=${gpus}
stage=0
stop_stage=100
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# finetune
python3 finetune.py \
--input_dir=${input_dir} \
--pretrained_model_dir=${pretrained_model_dir} \
--mfa_dir=${mfa_dir} \
--dump_dir=${dump_dir} \
--output_dir=${output_dir} \
--lang=${lang} \
--ngpu=${ngpu} \
--epoch=100
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "in hifigan syn_e2e"
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \
--am=fastspeech2_aishell3 \
--am_config=${pretrained_model_dir}/default.yaml \
--am_ckpt=${output_dir}/checkpoints/${ckpt}.pdz \
--am_stat=${pretrained_model_dir}/speech_stats.npy \
--voc=hifigan_aishell3 \
--voc_config=pretrained_models/hifigan_aishell3_ckpt_0.2.0/default.yaml \
--voc_ckpt=pretrained_models/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
--voc_stat=pretrained_models/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=./test_e2e \
--phones_dict=${dump_dir}/phone_id_map.txt \
--speaker_dict=${dump_dir}/speaker_id_map.txt \
--spk_id=0
fi
Loading…
Cancel
Save