From a232cd8b12ed29444a50b7b3fa1c7e3a7f965860 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 29 Dec 2021 17:35:46 +0800 Subject: [PATCH 1/7] Update fastspeech2.py --- paddlespeech/t2s/models/fastspeech2/fastspeech2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index a5fb7fab..295a2e4b 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -940,7 +940,8 @@ class StyleFastSpeech2Inference(FastSpeech2Inference): Tensor Output sequence of features (L, odim). """ - spk_id = paddle.to_tensor(spk_id) + if spk_id: + spk_id = paddle.to_tensor(spk_id) normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference( text, durations=None, From fbe3c05137feccf27a07fdb22d59bdd0318ca521 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 30 Dec 2021 10:29:12 +0800 Subject: [PATCH 2/7] add style_melgan and hifigan in tts cli, test=tts (#1241) --- paddlespeech/cli/tts/infer.py | 82 ++++++++++++++----- paddlespeech/t2s/exps/synthesize_e2e.py | 42 ++++++---- paddlespeech/t2s/frontend/phonectic.py | 49 ++++++++--- .../zh_normalization/text_normlization.py | 5 +- 4 files changed, 128 insertions(+), 50 deletions(-) diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index f60f4224..c934d595 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -178,6 +178,32 @@ pretrained_models = { 'speech_stats': 'feats_stats.npy', }, + # style_melgan + "style_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + '5de2d5348f396de0c966926b8c462755', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, } model_alias = { @@ -199,6 +225,14 @@ model_alias = { "paddlespeech.t2s.models.melgan:MelGANGenerator", "mb_melgan_inference": "paddlespeech.t2s.models.melgan:MelGANInference", + "style_melgan": + "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", + "style_melgan_inference": + "paddlespeech.t2s.models.melgan:StyleMelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", } @@ -266,7 +300,7 @@ class TTSExecutor(BaseExecutor): default='pwgan_csmsc', choices=[ 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', - 'mb_melgan_csmsc' + 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc' ], help='Choose vocoder type of tts task.') @@ -504,37 +538,47 @@ class TTSExecutor(BaseExecutor): am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] get_tone_ids = False + merge_sentences = False if am_name == 'speedyspeech': get_tone_ids = True if lang == 'zh': input_ids = self.frontend.get_input_ids( - text, merge_sentences=True, get_tone_ids=get_tone_ids) + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] - phone_ids = phone_ids[0] if get_tone_ids: tone_ids = input_ids["tone_ids"] - tone_ids = tone_ids[0] elif lang == 'en': - input_ids = self.frontend.get_input_ids(text) + input_ids = self.frontend.get_input_ids( + text, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") - # am - if am_name == 'speedyspeech': - mel = self.am_inference(phone_ids, tone_ids) - # fastspeech2 - else: - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - mel = self.am_inference( - phone_ids, spk_id=paddle.to_tensor(spk_id)) + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + # am + if am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + mel = self.am_inference(part_phone_ids, part_tone_ids) + # fastspeech2 else: - mel = self.am_inference(phone_ids) - - # voc - wav = self.voc_inference(mel) - self._outputs['wav'] = wav + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + mel = self.am_inference( + part_phone_ids, spk_id=paddle.to_tensor(spk_id)) + else: + mel = self.am_inference(part_phone_ids) + # voc + wav = self.voc_inference(mel) + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + self._outputs['wav'] = wav_all def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]: """ diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 9a83ec1b..fc822b21 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -196,41 +196,47 @@ def evaluate(args): output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - + merge_sentences = False for utt_id, sentence in sentences: get_tone_ids = False if am_name == 'speedyspeech': get_tone_ids = True if args.lang == 'zh': input_ids = frontend.get_input_ids( - sentence, merge_sentences=True, get_tone_ids=get_tone_ids) + sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] - phone_ids = phone_ids[0] if get_tone_ids: tone_ids = input_ids["tone_ids"] - tone_ids = tone_ids[0] elif args.lang == 'en': - input_ids = frontend.get_input_ids(sentence) + input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") - with paddle.no_grad(): - # acoustic model - if am_name == 'fastspeech2': - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - spk_id = paddle.to_tensor(args.spk_id) - mel = am_inference(phone_ids, spk_id) + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + # acoustic model + if am_name == 'fastspeech2': + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + spk_id = paddle.to_tensor(args.spk_id) + mel = am_inference(part_phone_ids, spk_id) + else: + mel = am_inference(part_phone_ids) + elif am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + mel = am_inference(part_phone_ids, part_tone_ids) + # vocoder + wav = voc_inference(mel) + if flags == 0: + wav_all = wav + flags = 1 else: - mel = am_inference(phone_ids) - elif am_name == 'speedyspeech': - mel = am_inference(phone_ids, tone_ids) - # vocoder - wav = voc_inference(mel) + wav_all = paddle.concat([wav_all, wav]) sf.write( str(output_dir / (utt_id + ".wav")), - wav.numpy(), + wav_all.numpy(), samplerate=am_config.fs) print(f"{utt_id} done!") diff --git a/paddlespeech/t2s/frontend/phonectic.py b/paddlespeech/t2s/frontend/phonectic.py index fbc8fd38..25413871 100644 --- a/paddlespeech/t2s/frontend/phonectic.py +++ b/paddlespeech/t2s/frontend/phonectic.py @@ -13,7 +13,9 @@ # limitations under the License. from abc import ABC from abc import abstractmethod +from typing import List +import numpy as np import paddle from g2p_en import G2p from g2pM import G2pM @@ -21,6 +23,7 @@ from g2pM import G2pM from paddlespeech.t2s.frontend.normalizer.normalizer import normalize from paddlespeech.t2s.frontend.punctuation import get_punctuations from paddlespeech.t2s.frontend.vocab import Vocab +from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer # discard opencc untill we find an easy solution to install it on windows # from opencc import OpenCC @@ -53,6 +56,7 @@ class English(Phonetics): self.vocab = Vocab(self.phonemes + self.punctuations) self.vocab_phones = {} self.punc = ":,;。?!“”‘’':,;.?!" + self.text_normalizer = TextNormalizer() if phone_vocab_path: with open(phone_vocab_path, 'rt') as f: phn_id = [line.strip().split() for line in f.readlines()] @@ -78,19 +82,42 @@ class English(Phonetics): phonemes = [item for item in phonemes if item in self.vocab.stoi] return phonemes - def get_input_ids(self, sentence: str) -> paddle.Tensor: - result = {} - phones = self.phoneticize(sentence) - # remove start_symbol and end_symbol - phones = phones[1:-1] - phones = [phn for phn in phones if not phn.isspace()] - phones = [ + def _p2id(self, phonemes: List[str]) -> np.array: + # replace unk phone with sp + phonemes = [ phn if (phn in self.vocab_phones and phn not in self.punc) else "sp" - for phn in phones + for phn in phonemes ] - phone_ids = [self.vocab_phones[phn] for phn in phones] - phone_ids = paddle.to_tensor(phone_ids) - result["phone_ids"] = phone_ids + phone_ids = [self.vocab_phones[item] for item in phonemes] + return np.array(phone_ids, np.int64) + + def get_input_ids(self, sentence: str, + merge_sentences: bool=False) -> paddle.Tensor: + result = {} + sentences = self.text_normalizer._split(sentence, lang="en") + phones_list = [] + temp_phone_ids = [] + for sentence in sentences: + phones = self.phoneticize(sentence) + # remove start_symbol and end_symbol + phones = phones[1:-1] + phones = [phn for phn in phones if not phn.isspace()] + phones_list.append(phones) + + if merge_sentences: + merge_list = sum(phones_list, []) + # rm the last 'sp' to avoid the noise at the end + # cause in the training data, no 'sp' in the end + if merge_list[-1] == 'sp': + merge_list = merge_list[:-1] + phones_list = [] + phones_list.append(merge_list) + + for part_phones_list in phones_list: + phone_ids = self._p2id(part_phones_list) + phone_ids = paddle.to_tensor(phone_ids) + temp_phone_ids.append(phone_ids) + result["phone_ids"] = temp_phone_ids return result def numericalize(self, phonemes): diff --git a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py index c68caeeb..c502d882 100644 --- a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py +++ b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py @@ -53,7 +53,7 @@ class TextNormalizer(): def __init__(self): self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)') - def _split(self, text: str) -> List[str]: + def _split(self, text: str, lang="zh") -> List[str]: """Split long text into sentences with sentence-splitting punctuations. Parameters ---------- @@ -65,7 +65,8 @@ class TextNormalizer(): Sentences. """ # Only for pure Chinese here - text = text.replace(" ", "") + if lang == "zh": + text = text.replace(" ", "") text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) text = text.strip() sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] From 420709e5ce66d4bf7ea94b0888d9d64f4a9dbd7c Mon Sep 17 00:00:00 2001 From: Junkun Chen Date: Wed, 29 Dec 2021 18:47:42 -0800 Subject: [PATCH 3/7] [st] Distributed sampler and new dataloader with MIMO (#1239) * update timit result, test=doc_fix * result update * fix bug * add triplet loader * empty preprocess file * sync to u2, updating * sync to u2 config * fix bugs * code refine * update config * customize decoding batch size * update optimizer and lr scheduler * minor * minor * minor * fix bugs of refs * minor * distributed sampler * minor * refine the loader --- examples/ted_en_zh/st0/conf/transformer.yaml | 17 +- .../st0/conf/transformer_mtl_noam.yaml | 6 +- examples/ted_en_zh/st0/local/test.sh | 2 - examples/ted_en_zh/st1/RESULTS.md | 2 +- examples/ted_en_zh/st1/conf/transformer.yaml | 54 ++- .../st1/conf/transformer_mtl_noam.yaml | 42 +-- examples/ted_en_zh/st1/local/test.sh | 2 - paddlespeech/s2t/exps/u2_st/model.py | 317 +++++++++--------- paddlespeech/s2t/io/converter.py | 78 +++-- paddlespeech/s2t/io/dataloader.py | 19 +- paddlespeech/s2t/io/reader.py | 9 +- 11 files changed, 292 insertions(+), 256 deletions(-) diff --git a/examples/ted_en_zh/st0/conf/transformer.yaml b/examples/ted_en_zh/st0/conf/transformer.yaml index 36f287b1..8afb107b 100644 --- a/examples/ted_en_zh/st0/conf/transformer.yaml +++ b/examples/ted_en_zh/st0/conf/transformer.yaml @@ -1,6 +1,6 @@ # https://yaml.org/type/float.html data: - train_manifest: data/manifest.train.tiny + train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test min_input_len: 0.05 # second @@ -15,8 +15,10 @@ collator: unit_type: 'spm' spm_model_prefix: data/lang_char/bpe_unigram_8000 mean_std_filepath: "" - # augmentation_config: conf/augmentation.json - batch_size: 10 + augmentation_config: conf/preprocess.yaml + batch_size: 16 + maxlen_in: 5 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced raw_wav: True # use raw_wav or kaldi feature spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -78,13 +80,13 @@ training: global_grad_clip: 5.0 optim: adam optim_conf: - lr: 0.004 + lr: 2.5 weight_decay: 1e-06 - scheduler: warmuplr + scheduler: noam scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 - log_interval: 5 + log_interval: 50 checkpoint: kbest_n: 50 latest_n: 5 @@ -97,6 +99,7 @@ decoding: alpha: 2.5 beta: 0.3 beam_size: 10 + word_reward: 0.7 cutoff_prob: 1.0 cutoff_top_n: 0 num_proc_bsearch: 8 @@ -107,3 +110,5 @@ decoding: # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. simulate_streaming: False # simulate streaming inference. Defaults to False. + + diff --git a/examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml b/examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml index 78887d3c..017230fe 100644 --- a/examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml +++ b/examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml @@ -15,8 +15,10 @@ collator: unit_type: 'spm' spm_model_prefix: data/lang_char/bpe_unigram_8000 mean_std_filepath: "" - # augmentation_config: conf/augmentation.json - batch_size: 10 + augmentation_config: conf/preprocess.yaml + batch_size: 16 + maxlen_in: 5 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced raw_wav: True # use raw_wav or kaldi feature spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 diff --git a/examples/ted_en_zh/st0/local/test.sh b/examples/ted_en_zh/st0/local/test.sh index a9b18dd9..0796a06e 100755 --- a/examples/ted_en_zh/st0/local/test.sh +++ b/examples/ted_en_zh/st0/local/test.sh @@ -13,14 +13,12 @@ ckpt_prefix=$2 for type in fullsentence; do echo "decoding ${type}" - batch_size=32 python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} \ - --opts decoding.batch_size ${batch_size} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/ted_en_zh/st1/RESULTS.md b/examples/ted_en_zh/st1/RESULTS.md index e8aed53e..66dbce6c 100644 --- a/examples/ted_en_zh/st1/RESULTS.md +++ b/examples/ted_en_zh/st1/RESULTS.md @@ -12,5 +12,5 @@ ## Transformer | Model | Params | Config | Val loss | Char-BLEU | | --- | --- | --- | --- | --- | -| FAT + Transformer+ASR MTL | 50.26M | conf/transformer_mtl_noam.yaml | 62.86 | 19.45 | +| FAT + Transformer+ASR MTL | 50.26M | conf/transformer_mtl_noam.yaml | 69.91 | 20.26 | | FAT + Transformer+ASR MTL with word reward | 50.26M | conf/transformer_mtl_noam.yaml | 62.86 | 20.80 | diff --git a/examples/ted_en_zh/st1/conf/transformer.yaml b/examples/ted_en_zh/st1/conf/transformer.yaml index 609c5824..a8918a23 100644 --- a/examples/ted_en_zh/st1/conf/transformer.yaml +++ b/examples/ted_en_zh/st1/conf/transformer.yaml @@ -1,39 +1,33 @@ # https://yaml.org/type/float.html data: - train_manifest: data/manifest.train.tiny + train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - min_input_len: 5.0 # frame - max_input_len: 3000.0 # frame - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.01 - max_output_input_ratio: 20.0 collator: - vocab_filepath: data/lang_char/vocab.txt + vocab_filepath: data/lang_char/ted_en_zh_bpe8000.txt unit_type: 'spm' - spm_model_prefix: data/lang_char/bpe_unigram_8000 + spm_model_prefix: data/lang_char/ted_en_zh_bpe8000 mean_std_filepath: "" # augmentation_config: conf/augmentation.json - batch_size: 10 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank + batch_size: 20 feat_dim: 83 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None stride_ms: 10.0 window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + augmentation_config: + num_workers: 0 + subsampling_factor: 1 + num_encs: 1 + # network architecture @@ -73,18 +67,18 @@ model: training: - n_epoch: 20 + n_epoch: 40 accum_grad: 2 global_grad_clip: 5.0 optim: adam optim_conf: - lr: 0.004 - weight_decay: 1e-06 - scheduler: warmuplr + lr: 2.5 + weight_decay: 0. + scheduler: noam scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 - log_interval: 5 + log_interval: 50 checkpoint: kbest_n: 50 latest_n: 5 @@ -107,4 +101,4 @@ decoding: # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. + simulate_streaming: False # simulate streaming inference. Defaults to False. \ No newline at end of file diff --git a/examples/ted_en_zh/st1/conf/transformer_mtl_noam.yaml b/examples/ted_en_zh/st1/conf/transformer_mtl_noam.yaml index 10eccd1e..3787037f 100644 --- a/examples/ted_en_zh/st1/conf/transformer_mtl_noam.yaml +++ b/examples/ted_en_zh/st1/conf/transformer_mtl_noam.yaml @@ -3,12 +3,6 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - min_input_len: 5.0 # frame - max_input_len: 3000.0 # frame - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.01 - max_output_input_ratio: 20.0 collator: vocab_filepath: data/lang_char/ted_en_zh_bpe8000.txt @@ -16,24 +10,24 @@ collator: spm_model_prefix: data/lang_char/ted_en_zh_bpe8000 mean_std_filepath: "" # augmentation_config: conf/augmentation.json - batch_size: 10 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank + batch_size: 20 feat_dim: 83 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None stride_ms: 10.0 window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + augmentation_config: + num_workers: 0 + subsampling_factor: 1 + num_encs: 1 + # network architecture @@ -73,18 +67,18 @@ model: training: - n_epoch: 20 + n_epoch: 40 accum_grad: 2 global_grad_clip: 5.0 optim: adam optim_conf: lr: 2.5 - weight_decay: 1e-06 + weight_decay: 0. scheduler: noam scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 - log_interval: 5 + log_interval: 50 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/ted_en_zh/st1/local/test.sh b/examples/ted_en_zh/st1/local/test.sh index a9b18dd9..0796a06e 100755 --- a/examples/ted_en_zh/st1/local/test.sh +++ b/examples/ted_en_zh/st1/local/test.sh @@ -13,14 +13,12 @@ ckpt_prefix=$2 for type in fullsentence; do echo "decoding ${type}" - batch_size=32 python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} \ - --opts decoding.batch_size ${batch_size} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index a3b39df7..4b671132 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -16,6 +16,7 @@ import json import os import time from collections import defaultdict +from collections import OrderedDict from contextlib import nullcontext from typing import Optional @@ -23,21 +24,18 @@ import jsonlines import numpy as np import paddle from paddle import distributed as dist -from paddle.io import DataLoader from yacs.config import CfgNode -from paddlespeech.s2t.io.collator import SpeechCollator -from paddlespeech.s2t.io.collator import TripletSpeechCollator -from paddlespeech.s2t.io.dataset import ManifestDataset -from paddlespeech.s2t.io.sampler import SortagradBatchSampler -from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler +from paddlespeech.s2t.frontend.featurizer import TextFeaturizer +from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.u2_st import U2STModel -from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog -from paddlespeech.s2t.training.scheduler import WarmupLR +from paddlespeech.s2t.training.optimizer import OptimizerFactory +from paddlespeech.s2t.training.reporter import ObsScope +from paddlespeech.s2t.training.reporter import report +from paddlespeech.s2t.training.scheduler import LRSchedulerFactory from paddlespeech.s2t.training.timer import Timer from paddlespeech.s2t.training.trainer import Trainer from paddlespeech.s2t.utils import bleu_score -from paddlespeech.s2t.utils import ctc_utils from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.log import Log @@ -96,6 +94,8 @@ class U2STTrainer(Trainer): # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad losses_np = {'loss': float(loss) * train_conf.accum_grad} + if st_loss: + losses_np['st_loss'] = float(st_loss) if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: @@ -125,6 +125,12 @@ class U2STTrainer(Trainer): iteration_time = time.time() - start + for k, v in losses_np.items(): + report(k, v) + report("batch_size", self.config.collator.batch_size) + report("accum", train_conf.accum_grad) + report("step_cost", iteration_time) + if (batch_index + 1) % train_conf.log_interval == 0: msg += "train time: {:>.3f}s, ".format(iteration_time) msg += "batch size: {}, ".format(self.config.collator.batch_size) @@ -204,16 +210,34 @@ class U2STTrainer(Trainer): data_start_time = time.time() for batch_index, batch in enumerate(self.train_loader): dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) - self.after_train_batch() - data_start_time = time.time() + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("lr", self.lr_scheduler()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips,sent./sec'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k.split(',')[0]}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += f" {k.split(',')[1]}" if len( + k.split(',')) == 2 else "" + msg += "," + msg = msg[:-1] # remove the last "," + if (batch_index + 1 + ) % self.config.training.log_interval == 0: + logger.info(msg) except Exception as e: logger.error(e) raise e @@ -244,95 +268,87 @@ class U2STTrainer(Trainer): def setup_dataloader(self): config = self.config.clone() - config.defrost() - config.collator.keep_transcription_text = False - - # train/valid dataset, return token ids - config.data.manifest = config.data.train_manifest - train_dataset = ManifestDataset.from_config(config) - - config.data.manifest = config.data.dev_manifest - dev_dataset = ManifestDataset.from_config(config) - - if config.model.model_conf.asr_weight > 0.: - Collator = TripletSpeechCollator - TestCollator = SpeechCollator - else: - TestCollator = Collator = SpeechCollator - collate_fn_train = Collator.from_config(config) - config.collator.augmentation_config = "" - collate_fn_dev = Collator.from_config(config) + load_transcript = True if config.model.model_conf.asr_weight > 0 else False - if self.parallel: - batch_sampler = SortagradDistributedBatchSampler( - train_dataset, + if self.train: + # train/valid dataset, return token ids + self.train_loader = BatchDataLoader( + json_file=config.data.train_manifest, + train_mode=True, + sortagrad=False, batch_size=config.collator.batch_size, - num_replicas=None, - rank=None, - shuffle=True, - drop_last=True, - sortagrad=config.collator.sortagrad, - shuffle_method=config.collator.shuffle_method) - else: - batch_sampler = SortagradBatchSampler( - train_dataset, - shuffle=True, + maxlen_in=config.collator.maxlen_in, + maxlen_out=config.collator.maxlen_out, + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.collator. + augmentation_config, # aug will be off when train_mode=False + n_iter_processes=config.collator.num_workers, + subsampling_factor=1, + load_aux_output=load_transcript, + num_encs=1) + + self.valid_loader = BatchDataLoader( + json_file=config.data.dev_manifest, + train_mode=False, + sortagrad=False, batch_size=config.collator.batch_size, - drop_last=True, - sortagrad=config.collator.sortagrad, - shuffle_method=config.collator.shuffle_method) - self.train_loader = DataLoader( - train_dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn_train, - num_workers=config.collator.num_workers, ) - self.valid_loader = DataLoader( - dev_dataset, - batch_size=config.collator.batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_dev, - num_workers=config.collator.num_workers, ) - - # test dataset, return raw text - config.data.manifest = config.data.test_manifest - # filter test examples, will cause less examples, but no mismatch with training - # and can use large batch size , save training time, so filter test egs now. - # config.data.min_input_len = 0.0 # second - # config.data.max_input_len = float('inf') # second - # config.data.min_output_len = 0.0 # tokens - # config.data.max_output_len = float('inf') # tokens - # config.data.min_output_input_ratio = 0.00 - # config.data.max_output_input_ratio = float('inf') - test_dataset = ManifestDataset.from_config(config) - # return text ord id - config.collator.keep_transcription_text = True - config.collator.augmentation_config = "" - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=TestCollator.from_config(config), - num_workers=config.collator.num_workers, ) - # return text token id - config.collator.keep_transcription_text = False - self.align_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=TestCollator.from_config(config), - num_workers=config.collator.num_workers, ) - logger.info("Setup train/valid/test/align Dataloader!") + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.collator. + augmentation_config, # aug will be off when train_mode=False + n_iter_processes=config.collator.num_workers, + subsampling_factor=1, + load_aux_output=load_transcript, + num_encs=1) + logger.info("Setup train/valid Dataloader!") + else: + # test dataset, return raw text + self.test_loader = BatchDataLoader( + json_file=config.data.test_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.decoding.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.collator. + augmentation_config, # aug will be off when train_mode=False + n_iter_processes=config.collator.num_workers, + subsampling_factor=1, + num_encs=1) + + logger.info("Setup test Dataloader!") def setup_model(self): config = self.config model_conf = config.model with UpdateConfig(model_conf): - model_conf.input_dim = self.train_loader.collate_fn.feature_size - model_conf.output_dim = self.train_loader.collate_fn.vocab_size + if self.train: + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + else: + model_conf.input_dim = self.test_loader.feat_dim + model_conf.output_dim = self.test_loader.vocab_size model = U2STModel.from_config(model_conf) @@ -348,35 +364,38 @@ class U2STTrainer(Trainer): scheduler_type = train_config.scheduler scheduler_conf = train_config.scheduler_conf - if scheduler_type == 'expdecaylr': - lr_scheduler = paddle.optimizer.lr.ExponentialDecay( - learning_rate=optim_conf.lr, - gamma=scheduler_conf.lr_decay, - verbose=False) - elif scheduler_type == 'warmuplr': - lr_scheduler = WarmupLR( - learning_rate=optim_conf.lr, - warmup_steps=scheduler_conf.warmup_steps, - verbose=False) - elif scheduler_type == 'noam': - lr_scheduler = paddle.optimizer.lr.NoamDecay( - learning_rate=optim_conf.lr, - d_model=model_conf.encoder_conf.output_size, - warmup_steps=scheduler_conf.warmup_steps, - verbose=False) - else: - raise ValueError(f"Not support scheduler: {scheduler_type}") - - grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip) - weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay) - if optim_type == 'adam': - optimizer = paddle.optimizer.Adam( - learning_rate=lr_scheduler, - parameters=model.parameters(), - weight_decay=weight_decay, - grad_clip=grad_clip) - else: - raise ValueError(f"Not support optim: {optim_type}") + scheduler_args = { + "learning_rate": optim_conf.lr, + "verbose": False, + "warmup_steps": scheduler_conf.warmup_steps, + "gamma": scheduler_conf.lr_decay, + "d_model": model_conf.encoder_conf.output_size, + } + lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, + scheduler_args) + + def optimizer_args( + config, + parameters, + lr_scheduler=None, ): + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + return { + "grad_clip": train_config.global_grad_clip, + "weight_decay": optim_conf.weight_decay, + "learning_rate": lr_scheduler + if lr_scheduler else optim_conf.lr, + "parameters": parameters, + "epsilon": 1e-9 if optim_type == 'noam' else None, + "beta1": 0.9 if optim_type == 'noam' else None, + "beat2": 0.98 if optim_type == 'noam' else None, + } + + optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) + optimizer = OptimizerFactory.from_args(optim_type, optimzer_args) self.model = model self.optimizer = optimizer @@ -416,26 +435,30 @@ class U2STTester(U2STTrainer): def __init__(self, config, args): super().__init__(config, args) + self.text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + self.vocab_list = self.text_feature.vocab_list - def ordid2token(self, texts, texts_len): + def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ trans = [] for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) + trans.append(text_feature.defeaturize(ids.numpy().tolist())) return trans def translate(self, audio, audio_len): """"E2E translation from extracted audio feature""" cfg = self.config.decoding - text_feature = self.test_loader.collate_fn.text_feature self.model.eval() hyps = self.model.decode( audio, audio_len, - text_feature=text_feature, + text_feature=self.text_feature, decoding_method=cfg.decoding_method, beam_size=cfg.beam_size, word_reward=cfg.word_reward, @@ -456,23 +479,20 @@ class U2STTester(U2STTrainer): len_refs, num_ins = 0, 0 start_time = time.time() - text_feature = self.test_loader.collate_fn.text_feature - refs = [ - "".join(chr(t) for t in text[:text_len]) - for text, text_len in zip(texts, texts_len) - ] + refs = self.id2token(texts, texts_len, self.text_feature) hyps = self.model.decode( audio, audio_len, - text_feature=text_feature, + text_feature=self.text_feature, decoding_method=cfg.decoding_method, beam_size=cfg.beam_size, word_reward=cfg.word_reward, decoding_chunk_size=cfg.decoding_chunk_size, num_decoding_left_chunks=cfg.num_decoding_left_chunks, simulate_streaming=cfg.simulate_streaming) + decode_time = time.time() - start_time for utt, target, result in zip(utts, refs, hyps): @@ -505,7 +525,7 @@ class U2STTester(U2STTrainer): cfg = self.config.decoding bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu - stride_ms = self.test_loader.collate_fn.stride_ms + stride_ms = self.config.collator.stride_ms hyps, refs = [], [] len_refs, num_ins = 0, 0 num_frames = 0.0 @@ -522,7 +542,7 @@ class U2STTester(U2STTrainer): len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] rtf = num_time / (num_frames * stride_ms) - logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu)) + logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu)) rtf = num_time / (num_frames * stride_ms) msg = "Test: " @@ -553,13 +573,6 @@ class U2STTester(U2STTrainer): }) f.write(data + '\n') - @paddle.no_grad() - def align(self): - ctc_utils.ctc_align(self.config, self.model, self.align_loader, - self.config.decoding.batch_size, - self.config.collator.stride_ms, self.vocab_list, - self.args.result_file) - def load_inferspec(self): """infer model and input spec. @@ -567,11 +580,11 @@ class U2STTester(U2STTrainer): nn.Layer: inference model List[paddle.static.InputSpec]: input spec. """ - from paddlespeech.s2t.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.test_loader, - self.config.model.clone(), - self.args.checkpoint_path) - feat_dim = self.test_loader.collate_fn.feature_size + from paddlespeech.s2t.models.u2_st import U2STInferModel + infer_model = U2STInferModel.from_pretrained(self.test_loader, + self.config.model.clone(), + self.args.checkpoint_path) + feat_dim = self.test_loader.feat_dim input_spec = [ paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'), # audio, [B,T,D] diff --git a/paddlespeech/s2t/io/converter.py b/paddlespeech/s2t/io/converter.py index b217d2b1..c92ef017 100644 --- a/paddlespeech/s2t/io/converter.py +++ b/paddlespeech/s2t/io/converter.py @@ -31,11 +31,17 @@ class CustomConverter(): """ - def __init__(self, subsampling_factor=1, dtype=np.float32): + def __init__(self, + subsampling_factor=1, + dtype=np.float32, + load_aux_input=False, + load_aux_output=False): """Construct a CustomConverter object.""" self.subsampling_factor = subsampling_factor self.ignore_id = -1 self.dtype = dtype + self.load_aux_input = load_aux_input + self.load_aux_output = load_aux_output def __call__(self, batch): """Transform a batch and send it to a device. @@ -49,34 +55,48 @@ class CustomConverter(): """ # batch should be located in list assert len(batch) == 1 - (xs, ys), utts = batch[0] - assert xs[0] is not None, "please check Reader and Augmentation impl." - - # perform subsampling - if self.subsampling_factor > 1: - xs = [x[::self.subsampling_factor, :] for x in xs] - - # get batch of lengths of input sequences - ilens = np.array([x.shape[0] for x in xs]) - - # perform padding and convert to tensor - # currently only support real number - if xs[0].dtype.kind == "c": - xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype) - xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype) - # Note(kamo): - # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E. - # Don't create ComplexTensor and give it E2E here - # because torch.nn.DataParellel can't handle it. - xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag} - else: - xs_pad = pad_list(xs, 0).astype(self.dtype) - + data, utts = batch[0] + xs_data, ys_data = [], [] + for ud in data: + if ud[0].ndim > 1: + # speech data (input): (speech_len, feat_dim) + xs_data.append(ud) + else: + # text data (output): (text_len, ) + ys_data.append(ud) + + assert xs_data[0][0] is not None, "please check Reader and Augmentation impl." + + xs_pad, ilens = [], [] + for xs in xs_data: + # perform subsampling + if self.subsampling_factor > 1: + xs = [x[::self.subsampling_factor, :] for x in xs] + + # get batch of lengths of input sequences + ilens.append(np.array([x.shape[0] for x in xs])) + + # perform padding and convert to tensor + # currently only support real number + xs_pad.append(pad_list(xs, 0).astype(self.dtype)) + + if not self.load_aux_input: + xs_pad, ilens = xs_pad[0], ilens[0] + break + # NOTE: this is for multi-output (e.g., speech translation) - ys_pad = pad_list( - [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], - self.ignore_id) + ys_pad, olens = [], [] + + for ys in ys_data: + ys_pad.append(pad_list( + [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], + self.ignore_id)) + + olens.append(np.array( + [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])) + + if not self.load_aux_output: + ys_pad, olens = ys_pad[0], olens[0] + break - olens = np.array( - [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]) return utts, xs_pad, ilens, ys_pad, olens diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index b8eb3367..8330b1da 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -19,6 +19,7 @@ from typing import Text import jsonlines import numpy as np from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler from paddlespeech.s2t.io.batchfy import make_batchset from paddlespeech.s2t.io.converter import CustomConverter @@ -73,6 +74,8 @@ class BatchDataLoader(): preprocess_conf=None, n_iter_processes: int=1, subsampling_factor: int=1, + load_aux_input: bool=False, + load_aux_output: bool=False, num_encs: int=1): self.json_file = json_file self.train_mode = train_mode @@ -89,6 +92,8 @@ class BatchDataLoader(): self.num_encs = num_encs self.preprocess_conf = preprocess_conf self.n_iter_processes = n_iter_processes + self.load_aux_input = load_aux_input + self.load_aux_output = load_aux_output # read json data with jsonlines.open(json_file, 'r') as reader: @@ -126,21 +131,29 @@ class BatchDataLoader(): # Setup a converter if num_encs == 1: self.converter = CustomConverter( - subsampling_factor=subsampling_factor, dtype=np.float32) + subsampling_factor=subsampling_factor, + dtype=np.float32, + load_aux_input=load_aux_input, + load_aux_output=load_aux_output) else: assert NotImplementedError("not impl CustomConverterMulEnc.") # hack to make batchsize argument as 1 # actual bathsize is included in a list - # default collate function converts numpy array to pytorch tensor + # default collate function converts numpy array to paddle tensor # we used an empty collate function instead which returns list self.dataset = TransformDataset(self.minibaches, self.converter, self.reader) - self.dataloader = DataLoader( + self.sampler = DistributedBatchSampler( dataset=self.dataset, batch_size=1, shuffle=not self.use_sortagrad if self.train_mode else False, + ) + + self.dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=self.sampler, collate_fn=batch_collate, num_workers=self.n_iter_processes, ) diff --git a/paddlespeech/s2t/io/reader.py b/paddlespeech/s2t/io/reader.py index 38ff1396..4e136bdc 100644 --- a/paddlespeech/s2t/io/reader.py +++ b/paddlespeech/s2t/io/reader.py @@ -68,7 +68,7 @@ class LoadInputsAndTargets(): if mode not in ["asr"]: raise ValueError("Only asr are allowed: mode={}".format(mode)) - if preprocess_conf is not None: + if preprocess_conf: self.preprocessing = Transformation(preprocess_conf) logger.warning( "[Experimental feature] Some preprocessing will be done " @@ -82,12 +82,11 @@ class LoadInputsAndTargets(): self.load_output = load_output self.load_input = load_input self.sort_in_input_length = sort_in_input_length - if preprocess_args is None: - self.preprocess_args = {} - else: + if preprocess_args: assert isinstance(preprocess_args, dict), type(preprocess_args) self.preprocess_args = dict(preprocess_args) - + else: + self.preprocess_args = {} self.keep_all_data_on_mem = keep_all_data_on_mem def __call__(self, batch, return_uttid=False): From 6d93f3e55e8ff75dec638d5c4e9c9b8352b7b07f Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 30 Dec 2021 11:20:06 +0800 Subject: [PATCH 4/7] add yeyupiaoling's repos' links, test=doc_fix (#1243) --- README.md | 2 +- README_cn.md | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a1211baa..08d5dc99 100644 --- a/README.md +++ b/README.md @@ -530,7 +530,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P ## Acknowledgement -- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling) for years of attention, constructive advice and great help. +- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help. - Many thanks to [AK391](https://github.com/AK391) for TTS web demo on Huggingface Spaces using Gradio. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files. - Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function. diff --git a/README_cn.md b/README_cn.md index 33e66c4b..43790039 100644 --- a/README_cn.md +++ b/README_cn.md @@ -497,7 +497,6 @@ year={2021} ## 参与 PaddleSpeech 的开发 - 热烈欢迎您在[Discussions](https://github.com/PaddlePaddle/PaddleSpeech/discussions) 中提交问题,并在[Issues](https://github.com/PaddlePaddle/PaddleSpeech/issues) 中指出发现的 bug。此外,我们非常希望您参与到 PaddleSpeech 的开发中! ### 贡献者 @@ -539,7 +538,7 @@ year={2021} ## 致谢 -- 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling) 多年来的关注和建议,以及在诸多问题上的帮助。 +- 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) 多年来的关注和建议,以及在诸多问题上的帮助。 - 非常感谢 [AK391](https://github.com/AK391) 在 Huggingface Spaces 上使用 Gradio 对我们的语音合成功能进行网页版演示。 - 非常感谢 [mymagicpower](https://github.com/mymagicpower) 采用PaddleSpeech 对 ASR 的[短语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk)及[长语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk)进行 Java 实现。 - 非常感谢 [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) 采用 PaddleSpeech 语音合成功能实现 Virtual Uploader(VUP)/Virtual YouTuber(VTuber) 虚拟主播。 From c81a3f0f8388b97c2fea6bccb64cd6dfa0c05439 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 30 Dec 2021 12:52:48 +0800 Subject: [PATCH 5/7] [s2t] DataLoader with BatchSampler or DistributeBatchSampler (#1242) * batchsampler or distributebatchsampler * format --- paddlespeech/s2t/exps/u2_st/model.py | 12 +++++++---- paddlespeech/s2t/io/converter.py | 27 +++++++++++++++---------- paddlespeech/s2t/io/dataloader.py | 25 +++++++++++++++++------ paddlespeech/t2s/exps/synthesize_e2e.py | 7 +++++-- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 4b671132..89408786 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -292,7 +292,8 @@ class U2STTrainer(Trainer): n_iter_processes=config.collator.num_workers, subsampling_factor=1, load_aux_output=load_transcript, - num_encs=1) + num_encs=1, + dist_sampler=True) self.valid_loader = BatchDataLoader( json_file=config.data.dev_manifest, @@ -313,7 +314,8 @@ class U2STTrainer(Trainer): n_iter_processes=config.collator.num_workers, subsampling_factor=1, load_aux_output=load_transcript, - num_encs=1) + num_encs=1, + dist_sampler=True) logger.info("Setup train/valid Dataloader!") else: # test dataset, return raw text @@ -335,7 +337,8 @@ class U2STTrainer(Trainer): augmentation_config, # aug will be off when train_mode=False n_iter_processes=config.collator.num_workers, subsampling_factor=1, - num_encs=1) + num_encs=1, + dist_sampler=False) logger.info("Setup test Dataloader!") @@ -542,7 +545,8 @@ class U2STTester(U2STTrainer): len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] rtf = num_time / (num_frames * stride_ms) - logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu)) + logger.info("RTF: %f, instance (%d), batch BELU = %f" % + (rtf, num_ins, bleu)) rtf = num_time / (num_frames * stride_ms) msg = "Test: " diff --git a/paddlespeech/s2t/io/converter.py b/paddlespeech/s2t/io/converter.py index c92ef017..a802ac74 100644 --- a/paddlespeech/s2t/io/converter.py +++ b/paddlespeech/s2t/io/converter.py @@ -65,8 +65,9 @@ class CustomConverter(): # text data (output): (text_len, ) ys_data.append(ud) - assert xs_data[0][0] is not None, "please check Reader and Augmentation impl." - + assert xs_data[0][ + 0] is not None, "please check Reader and Augmentation impl." + xs_pad, ilens = [], [] for xs in xs_data: # perform subsampling @@ -79,22 +80,26 @@ class CustomConverter(): # perform padding and convert to tensor # currently only support real number xs_pad.append(pad_list(xs, 0).astype(self.dtype)) - + if not self.load_aux_input: xs_pad, ilens = xs_pad[0], ilens[0] break - + # NOTE: this is for multi-output (e.g., speech translation) ys_pad, olens = [], [] - + for ys in ys_data: - ys_pad.append(pad_list( - [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], - self.ignore_id)) + ys_pad.append( + pad_list([ + np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys + ], self.ignore_id)) + + olens.append( + np.array([ + y[0].shape[0] if isinstance(y, tuple) else y.shape[0] + for y in ys + ])) - olens.append(np.array( - [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])) - if not self.load_aux_output: ys_pad, olens = ys_pad[0], olens[0] break diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 8330b1da..455303f7 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -18,6 +18,7 @@ from typing import Text import jsonlines import numpy as np +from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler @@ -76,7 +77,8 @@ class BatchDataLoader(): subsampling_factor: int=1, load_aux_input: bool=False, load_aux_output: bool=False, - num_encs: int=1): + num_encs: int=1, + dist_sampler: bool=False): self.json_file = json_file self.train_mode = train_mode self.use_sortagrad = sortagrad == -1 or sortagrad > 0 @@ -94,6 +96,7 @@ class BatchDataLoader(): self.n_iter_processes = n_iter_processes self.load_aux_input = load_aux_input self.load_aux_output = load_aux_output + self.dist_sampler = dist_sampler # read json data with jsonlines.open(json_file, 'r') as reader: @@ -145,11 +148,18 @@ class BatchDataLoader(): self.dataset = TransformDataset(self.minibaches, self.converter, self.reader) - self.sampler = DistributedBatchSampler( - dataset=self.dataset, - batch_size=1, - shuffle=not self.use_sortagrad if self.train_mode else False, - ) + if self.dist_sampler: + self.sampler = DistributedBatchSampler( + dataset=self.dataset, + batch_size=1, + shuffle=not self.use_sortagrad if self.train_mode else False, + drop_last=False, ) + else: + self.sampler = BatchSampler( + dataset=self.dataset, + batch_size=1, + shuffle=not self.use_sortagrad if self.train_mode else False, + drop_last=False, ) self.dataloader = DataLoader( dataset=self.dataset, @@ -181,5 +191,8 @@ class BatchDataLoader(): echo += f"subsampling_factor: {self.subsampling_factor}, " echo += f"num_encs: {self.num_encs}, " echo += f"num_workers: {self.n_iter_processes}, " + echo += f"load_aux_input: {self.load_aux_input}, " + echo += f"load_aux_output: {self.load_aux_output}, " + echo += f"dist_sampler: {self.dist_sampler}, " echo += f"file: {self.json_file}" return echo diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index fc822b21..15ed1e4d 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -203,12 +203,15 @@ def evaluate(args): get_tone_ids = True if args.lang == 'zh': input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] if get_tone_ids: tone_ids = input_ids["tone_ids"] elif args.lang == 'en': - input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences) + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") From 326fcd520ae9d5646b56d79bf27a9c899b78d108 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 30 Dec 2021 08:00:20 +0000 Subject: [PATCH 6/7] fix config, test=tts --- examples/aishell3/voc1/conf/default.yaml | 5 +---- examples/csmsc/voc1/conf/default.yaml | 5 +---- examples/csmsc/voc4/conf/default.yaml | 2 +- examples/csmsc/voc5/conf/default.yaml | 2 +- examples/csmsc/voc5/conf/finetune.yaml | 2 +- examples/ljspeech/voc1/conf/default.yaml | 5 +---- examples/vctk/voc1/conf/default.yaml | 5 +---- 7 files changed, 7 insertions(+), 19 deletions(-) diff --git a/examples/aishell3/voc1/conf/default.yaml b/examples/aishell3/voc1/conf/default.yaml index 88968d6f..7fbffbdd 100644 --- a/examples/aishell3/voc1/conf/default.yaml +++ b/examples/aishell3/voc1/conf/default.yaml @@ -72,10 +72,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. ########################################################### batch_size: 8 # Batch size. batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift. -pin_memory: true # Whether to pin memory in Pytorch DataLoader. -num_workers: 4 # Number of workers in Pytorch DataLoader. -remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. -allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/csmsc/voc1/conf/default.yaml b/examples/csmsc/voc1/conf/default.yaml index 9ea81b8d..28d218ff 100644 --- a/examples/csmsc/voc1/conf/default.yaml +++ b/examples/csmsc/voc1/conf/default.yaml @@ -79,10 +79,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. ########################################################### batch_size: 8 # Batch size. batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by n_shift. -pin_memory: true # Whether to pin memory in Pytorch DataLoader. -num_workers: 2 # Number of workers in Pytorch DataLoader. -remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. -allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/csmsc/voc4/conf/default.yaml b/examples/csmsc/voc4/conf/default.yaml index 6f7d0f2b..c9abf78d 100644 --- a/examples/csmsc/voc4/conf/default.yaml +++ b/examples/csmsc/voc4/conf/default.yaml @@ -88,7 +88,7 @@ discriminator_adv_loss_params: batch_size: 32 # Batch size. # batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300, n_shift) batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift. -num_workers: 2 # Number of workers in Pytorch DataLoader. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/csmsc/voc5/conf/default.yaml b/examples/csmsc/voc5/conf/default.yaml index 5192d389..f42fc385 100644 --- a/examples/csmsc/voc5/conf/default.yaml +++ b/examples/csmsc/voc5/conf/default.yaml @@ -119,7 +119,7 @@ lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. ########################################################### batch_size: 16 # Batch size. batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. -num_workers: 2 # Number of workers in Pytorch DataLoader. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/csmsc/voc5/conf/finetune.yaml b/examples/csmsc/voc5/conf/finetune.yaml index 9876e93d..73420625 100644 --- a/examples/csmsc/voc5/conf/finetune.yaml +++ b/examples/csmsc/voc5/conf/finetune.yaml @@ -119,7 +119,7 @@ lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. ########################################################### batch_size: 16 # Batch size. batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. -num_workers: 2 # Number of workers in Pytorch DataLoader. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/ljspeech/voc1/conf/default.yaml b/examples/ljspeech/voc1/conf/default.yaml index bef2d681..2d39beb7 100644 --- a/examples/ljspeech/voc1/conf/default.yaml +++ b/examples/ljspeech/voc1/conf/default.yaml @@ -72,10 +72,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. ########################################################### batch_size: 8 # Batch size. batch_max_steps: 25600 # Length of each audio in batch. Make sure dividable by n_shift. -pin_memory: true # Whether to pin memory in Pytorch DataLoader. -num_workers: 4 # Number of workers in Pytorch DataLoader. -remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. -allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # diff --git a/examples/vctk/voc1/conf/default.yaml b/examples/vctk/voc1/conf/default.yaml index d95eaad9..aa382e21 100644 --- a/examples/vctk/voc1/conf/default.yaml +++ b/examples/vctk/voc1/conf/default.yaml @@ -72,10 +72,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. ########################################################### batch_size: 8 # Batch size. batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by n_shift. -pin_memory: true # Whether to pin memory in Pytorch DataLoader. -num_workers: 4 # Number of workers in Pytorch DataLoader. -remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. -allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +num_workers: 2 # Number of workers in DataLoader. ########################################################### # OPTIMIZER & SCHEDULER SETTING # From b9a55262f1414977363ba6c380355dabbc766016 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 30 Dec 2021 16:57:01 +0800 Subject: [PATCH 7/7] Update fastspeech2.py --- paddlespeech/t2s/models/fastspeech2/fastspeech2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 295a2e4b..405ad957 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -940,8 +940,6 @@ class StyleFastSpeech2Inference(FastSpeech2Inference): Tensor Output sequence of features (L, odim). """ - if spk_id: - spk_id = paddle.to_tensor(spk_id) normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference( text, durations=None,