diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug-report-asr.md
similarity index 87%
rename from .github/ISSUE_TEMPLATE/bug_report.md
rename to .github/ISSUE_TEMPLATE/bug-report-asr.md
index b31d98631..44f3c1401 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug-report-asr.md
@@ -1,9 +1,9 @@
 ---
-name: Bug report
+name: "\U0001F41B ASR Bug Report"
 about: Create a report to help us improve
 title: ''
-labels: ''
-assignees: ''
+labels: Bug, S2T
+assignees: zh794390558
 
 ---
 
@@ -27,7 +27,7 @@ A clear and concise description of what you expected to happen.
 
 **Screenshots**
 If applicable, add screenshots to help explain your problem.
-** Environment (please complete the following information):**
+**Environment (please complete the following information):**
  - OS: [e.g. Ubuntu]
  - GCC/G++ Version [e.g. 8.3]
  - Python Version [e.g. 3.7]
diff --git a/.github/ISSUE_TEMPLATE/bug-report-tts.md b/.github/ISSUE_TEMPLATE/bug-report-tts.md
new file mode 100644
index 000000000..d8f7afa82
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-report-tts.md
@@ -0,0 +1,42 @@
+---
+name: "\U0001F41B TTS Bug Report"
+about: Create a report to help us improve
+title: ''
+labels: Bug, T2S
+assignees: yt605155624
+
+---
+
+For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions).
+
+If you've found a bug then please create an issue with the following information:
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Environment (please complete the following information):**
+ - OS: [e.g. Ubuntu]
+ - GCC/G++ Version [e.g. 8.3]
+ - Python Version [e.g. 3.7]
+ - PaddlePaddle Version [e.g. 2.0.0]
+ - Model Version [e.g. 2.0.0]
+ - GPU/DRIVER Information [e.g. Tesla V100-SXM2-32GB/440.64.00]
+ - CUDA/CUDNN Version [e.g. cuda-10.2]
+ - MKL Version
+ - TensorRT Version
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md
new file mode 100644
index 000000000..8f7e094da
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.md
@@ -0,0 +1,19 @@
+---
+name: "\U0001F680 Feature Request"
+about: As a user, I want to request a New Feature on the product.
+title: ''
+labels: feature request
+assignees: ''
+
+---
+
+## Feature Request
+
+**Is your feature request related to a problem? Please describe:**
+
+
+**Describe the feature you'd like:**
+
+
+**Describe alternatives you've considered:**
+
diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md
new file mode 100644
index 000000000..445905c61
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/question.md
@@ -0,0 +1,19 @@
+---
+name: "\U0001F914 Ask a Question"
+about: I want to ask a question.
+title: ''
+labels: Question
+assignees: ''
+
+---
+
+## General Question
+
+
diff --git a/README.md b/README.md
index 7f10fc02e..acbe12309 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,3 @@
-
 ([简体中文](./README_cn.md)|English)

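The table hunks below add a `finetune` recipe to the FastSpeech2 row and introduce a new ERNIE-SAT acoustic-model row (VCTK / AISHELL-3 / ZH_EN). For orientation, the FastSpeech2 entries in this table are runnable through the TTS CLI executor; a minimal sketch, assuming the executor's documented defaults and the `{model_name}_{dataset}` tags used in the table:

```python
from paddlespeech.cli.tts.infer import TTSExecutor

tts = TTSExecutor()
# am/voc tags follow the {model_name}_{dataset} convention from the table below
tts(
    text="Hello from PaddleSpeech.",
    am="fastspeech2_ljspeech",
    voc="pwgan_ljspeech",
    lang="en",
    output="output.wav",
)
```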
@@ -535,7 +534,7 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r - Acoustic Model + Acoustic Model Tacotron2 LJSpeech / CSMSC @@ -558,9 +557,16 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r FastSpeech2 - LJSpeech / VCTK / CSMSC / AISHELL-3 / ZH_EN + LJSpeech / VCTK / CSMSC / AISHELL-3 / ZH_EN / finetune + + fastspeech2-ljspeech / fastspeech2-vctk / fastspeech2-csmsc / fastspeech2-aishell3 / fastspeech2-zh_en / fastspeech2-finetune + + + + ERNIE-SAT + VCTK / AISHELL-3 / ZH_EN - fastspeech2-ljspeech / fastspeech2-vctk / fastspeech2-csmsc / fastspeech2-aishell3 / fastspeech2-zh_en + ERNIE-SAT-vctk / ERNIE-SAT-aishell3 / ERNIE-SAT-zh_en @@ -793,40 +799,73 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P ### Contributors

[contributor avatar grid: the hunk removes the old `<a><img>` rows and adds a larger set; only the bare -/+ markers survive here]

 ## Acknowledgement
diff --git a/README_cn.md b/README_cn.md
index b4bd53f36..dbbc13ac0 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -1,4 +1,3 @@
-
 (简体中文|[English](./README.md))

@@ -530,7 +529,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 - 声学模型 + 声学模型 Tacotron2 LJSpeech / CSMSC @@ -553,9 +552,16 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 FastSpeech2 - LJSpeech / VCTK / CSMSC / AISHELL-3 / ZH_EN + LJSpeech / VCTK / CSMSC / AISHELL-3 / ZH_EN / finetune + + fastspeech2-ljspeech / fastspeech2-vctk / fastspeech2-csmsc / fastspeech2-aishell3 / fastspeech2-zh_en / fastspeech2-finetune + + + + ERNIE-SAT + VCTK / AISHELL-3 / ZH_EN - fastspeech2-ljspeech / fastspeech2-vctk / fastspeech2-csmsc / fastspeech2-aishell3 / fastspeech2-zh_en + ERNIE-SAT-vctk / ERNIE-SAT-aishell3 / ERNIE-SAT-zh_en @@ -797,40 +803,73 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 ### 贡献者

[contributor avatar grid: the hunk removes the old `<a><img>` rows and adds a larger set; only the bare -/+ markers survive here]

## 致谢 diff --git a/examples/ernie_sat/.meta/framework.png b/examples/ernie_sat/.meta/framework.png deleted file mode 100644 index c68f62467..000000000 Binary files a/examples/ernie_sat/.meta/framework.png and /dev/null differ diff --git a/examples/ernie_sat/README.md b/examples/ernie_sat/README.md deleted file mode 100644 index d3bd13372..000000000 --- a/examples/ernie_sat/README.md +++ /dev/null @@ -1,137 +0,0 @@ -ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 - -## 模型框架 -ERNIE-SAT 中我们提出了两项创新: -- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射 -- 采用语言和语音的联合掩码学习实现了语言和语音的对齐 - -[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3lOXKJXE-1655380879339)(.meta/framework.png)] - -## 使用说明 - -### 1.安装飞桨与环境依赖 - -- 本项目的代码基于 Paddle(version>=2.0) -- 本项目开放提供加载 torch 版本的 vocoder 的功能 - - torch version>=1.8 - -- 安装 htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的 htk (例如 3.4.1)。同时提供[历史版本 htk 下载地址](https://htk.eng.cam.ac.uk/ftp/software/) - - - 1.注册账号,下载 htk - - 2.解压 htk 文件,**放入项目根目录的 tools 文件夹中, 以 htk 文件夹名称放入** - - 3.**注意**: 如果您下载的是 3.4.1 或者更高版本, 需要进入 HTKLib/HRec.c 文件中, **修改 1626 行和 1650 行**, 即把**以下两行的 dur<=0 都修改为 dur<0**,如下所示: - ```bash - 以htk3.4.1版本举例: - (1)第1626行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0"); - 修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0"); - - (2)1650行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 "); - 修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 "); - ``` - - 4.**编译**: 详情参见解压后的 htk 中的 README 文件(如果未编译, 则无法正常运行) - - - -- 安装 ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN 项目并且安装相关依赖即可。 - - -- 安装其他依赖: **sox, libsndfile**等 - -### 2.预训练模型 -预训练模型 ERNIE-SAT 的模型如下所示: -- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz) -- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz) -- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz) - - -创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压: -```bash -mkdir pretrained_model -cd pretrained_model -tar -zxvf model-ernie-sat-base-en.tar.gz -tar -zxvf model-ernie-sat-base-zh.tar.gz -tar -zxvf model-ernie-sat-base-en_zh.tar.gz -``` - -### 3.下载 - -1. 本项目使用 parallel wavegan 作为声码器(vocoder): - - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip) - - 创建 download 文件夹,下载上述预训练的声码器(vocoder)模型并将其解压: - - ```bash - mkdir download - cd download - unzip pwg_aishell3_ckpt_0.5.zip - ``` - -2. 本项目使用 [FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器: - - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用 - - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用 - - 下载上述预训练的 fastspeech2 模型并将其解压: - - ```bash - cd download - unzip fastspeech2_conformer_baker_ckpt_0.5.zip - unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip - ``` - -3. 
本项目使用 HTK 获取输入音频和文本的对齐信息: - - - [aligner.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/aligner.zip) - - 下载上述文件到 tools 文件夹并将其解压: - ```bash - cd tools - unzip aligner.zip - ``` - - -### 4.推理 - -本项目当前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。 -注:当前英文场下的合成语音采用的声码器默认为 vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若 use_pt_vocoder 参数设置为 False,则英文场景下使用 paddle 版本的声码器。 - -我们提供特定音频文件, 以及其对应的文本、音素相关文件: -- prompt_wav: 提供的音频文件 -- prompt/dev: 基于上述特定音频对应的文本、音素相关文件 - - -```text -prompt_wav -├── p299_096.wav # 样例语音文件1 -├── p243_313.wav # 样例语音文件2 -└── ... -``` - -```text -prompt/dev -├── text # 样例语音对应文本 -├── wav.scp # 样例语音路径 -├── mfa_text # 样例语音对应音素 -├── mfa_start # 样例语音中各个音素的开始时间 -└── mfa_end # 样例语音中各个音素的结束时间 -``` -1. `--am` 声学模型格式符合 {model_name}_{dataset} -2. `--am_config`, `--am_checkpoint`, `--am_stat` 和 `--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。 -3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset} -4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。 -5. `--lang` 对应模型的语言可以是 `zh` 或 `en` 。 -6. `--ngpu` 要使用的 GPU 数,如果 ngpu==0,则使用 cpu。 -7. `--model_name` 模型名称 -8. `--uid` 特定提示(prompt)语音的 id -9. `--new_str` 输入的文本(本次开源暂时先设置特定的文本) -10. `--prefix` 特定音频对应的文本、音素相关文件的地址 -11. `--source_lang` , 源语言 -12. `--target_lang` , 目标语言 -13. `--output_name` , 合成语音名称 -14. `--task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务 - -运行以下脚本即可进行实验 -```shell -./run_sedit_en.sh # 语音编辑任务(英文) -./run_gen_en.sh # 个性化语音合成任务(英文) -./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) -``` diff --git a/examples/ernie_sat/local/align.py b/examples/ernie_sat/local/align.py deleted file mode 100755 index ff47cac5b..000000000 --- a/examples/ernie_sat/local/align.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Usage: - align.py wavfile trsfile outwordfile outphonefile -""" -import os -import sys - -PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' -MODEL_DIR_EN = 'tools/aligner/english' -MODEL_DIR_ZH = 'tools/aligner/mandarin' -HVITE = 'tools/htk/HTKTools/HVite' -HCOPY = 'tools/htk/HTKTools/HCopy' - - -def get_unk_phns(word_str: str): - tmpbase = '/tmp/tp.' 
- f = open(tmpbase + 'temp.words', 'w') - f.write(word_str) - f.close() - os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + - 'temp.phons') - f = open(tmpbase + 'temp.phons', 'r') - lines2 = f.readline().strip().split() - f.close() - phns = [] - for phn in lines2: - phons = phn.replace('\n', '').replace(' ', '') - seq = [] - j = 0 - while (j < len(phons)): - if (phons[j] > 'Z'): - if (phons[j] == 'j'): - seq.append('JH') - elif (phons[j] == 'h'): - seq.append('HH') - else: - seq.append(phons[j].upper()) - j += 1 - else: - p = phons[j:j + 2] - if (p == 'WH'): - seq.append('W') - elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): - seq.append(p) - elif (p == 'AX'): - seq.append('AH0') - else: - seq.append(p + '1') - j += 2 - phns.extend(seq) - return phns - - -def words2phns(line: str): - ''' - Args: - line (str): input text. - eg: for that reason cover is impossible to be given. - Returns: - List[str]: phones of input text. - eg: - ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0', - 'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1', - 'G', 'IH1', 'V', 'AH0', 'N'] - - Dict(str, str): key - idx_word - value - phones - eg: - {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], - '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'], - '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']} - ''' - dictfile = MODEL_DIR_EN + '/dict' - line = line.strip() - words = [] - for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - ds = set([]) - word2phns_dict = {} - with open(dictfile, 'r') as fid: - for line in fid: - word = line.split()[0] - ds.add(word) - if word not in word2phns_dict.keys(): - word2phns_dict[word] = " ".join(line.split()[1:]) - - phns = [] - wrd2phns = {} - for index, wrd in enumerate(words): - if wrd == '[MASK]': - wrd2phns[str(index) + "_" + wrd] = [wrd] - phns.append(wrd) - elif (wrd.upper() not in ds): - wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd) - phns.extend(get_unk_phns(wrd)) - else: - wrd2phns[str(index) + - "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split() - phns.extend(word2phns_dict[wrd.upper()].split()) - return phns, wrd2phns - - -def words2phns_zh(line: str): - dictfile = MODEL_DIR_ZH + '/dict' - line = line.strip() - words = [] - for pun in [ - ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', - u'。', u':', u';', u'!', u'?', u'(', u')' - ]: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - - ds = set([]) - word2phns_dict = {} - with open(dictfile, 'r') as fid: - for line in fid: - word = line.split()[0] - ds.add(word) - if word not in word2phns_dict.keys(): - word2phns_dict[word] = " ".join(line.split()[1:]) - - phns = [] - wrd2phns = {} - for index, wrd in enumerate(words): - if wrd == '[MASK]': - wrd2phns[str(index) + "_" + wrd] = [wrd] - phns.append(wrd) - elif (wrd.upper() not in ds): - print("出现非法词错误,请输入正确的文本...") - else: - wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split() - phns.extend(word2phns_dict[wrd].split()) - - return phns, wrd2phns - - -def 
prep_txt_zh(line: str, tmpbase: str, dictfile: str): - - words = [] - line = line.strip() - for pun in [ - ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', - u'。', u':', u';', u'!', u'?', u'(', u')' - ]: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - - ds = set([]) - with open(dictfile, 'r') as fid: - for line in fid: - ds.add(line.split()[0]) - - unk_words = set([]) - with open(tmpbase + '.txt', 'w') as fwid: - for wrd in words: - if (wrd not in ds): - unk_words.add(wrd) - fwid.write(wrd + ' ') - fwid.write('\n') - return unk_words - - -def prep_txt_en(line: str, tmpbase, dictfile): - - words = [] - - line = line.strip() - for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - - ds = set([]) - with open(dictfile, 'r') as fid: - for line in fid: - ds.add(line.split()[0]) - - unk_words = set([]) - with open(tmpbase + '.txt', 'w') as fwid: - for wrd in words: - if (wrd.upper() not in ds): - unk_words.add(wrd.upper()) - fwid.write(wrd + ' ') - fwid.write('\n') - - #generate pronounciations for unknows words using 'letter to sound' - with open(tmpbase + '_unk.words', 'w') as fwid: - for unk in unk_words: - fwid.write(unk + '\n') - try: - os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + - '_unk.phons') - except Exception: - print('english2phoneme error!') - sys.exit(1) - - #add unknown words to the standard dictionary, generate a tmp dictionary for alignment - fw = open(tmpbase + '.dict', 'w') - with open(dictfile, 'r') as fid: - for line in fid: - fw.write(line) - f = open(tmpbase + '_unk.words', 'r') - lines1 = f.readlines() - f.close() - f = open(tmpbase + '_unk.phons', 'r') - lines2 = f.readlines() - f.close() - for i in range(len(lines1)): - wrd = lines1[i].replace('\n', '') - phons = lines2[i].replace('\n', '').replace(' ', '') - seq = [] - j = 0 - while (j < len(phons)): - if (phons[j] > 'Z'): - if (phons[j] == 'j'): - seq.append('JH') - elif (phons[j] == 'h'): - seq.append('HH') - else: - seq.append(phons[j].upper()) - j += 1 - else: - p = phons[j:j + 2] - if (p == 'WH'): - seq.append('W') - elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): - seq.append(p) - elif (p == 'AX'): - seq.append('AH0') - else: - seq.append(p + '1') - j += 2 - - fw.write(wrd + ' ') - for s in seq: - fw.write(' ' + s) - fw.write('\n') - fw.close() - - -def prep_mlf(txt: str, tmpbase: str): - - with open(tmpbase + '.mlf', 'w') as fwid: - fwid.write('#!MLF!#\n') - fwid.write('"' + tmpbase + '.lab"\n') - fwid.write('sp\n') - wrds = txt.split() - for wrd in wrds: - fwid.write(wrd.upper() + '\n') - fwid.write('sp\n') - fwid.write('.\n') - - -def _get_user(): - return os.path.expanduser('~').split("/")[-1] - - -def alignment(wav_path: str, text: str): - ''' - intervals: List[phn, start, end] - ''' - tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) - - #prepare wav and trs files - try: - os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') - except Exception: - print('sox error!') - return None - - #prepare clean_transcript file - try: - prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict') - except Exception: - print('prep_txt error!') - return None - - #prepare mlf file - try: - with open(tmpbase + '.txt', 'r') as fid: - txt = 
fid.readline() - prep_mlf(txt, tmpbase) - except Exception: - print('prep_mlf error!') - return None - - #prepare scp - try: - os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase + - '.wav' + ' ' + tmpbase + '.plp') - except Exception: - print('HCopy error!') - return None - - #run alignment - try: - os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + - '.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN - + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + - '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase + - '.plp 2>&1 > /dev/null') - except Exception: - print('HVite error!') - return None - - with open(tmpbase + '.txt', 'r') as fid: - words = fid.readline().strip().split() - words = txt.strip().split() - words.reverse() - - with open(tmpbase + '.aligned', 'r') as fid: - lines = fid.readlines() - i = 2 - intervals = [] - word2phns = {} - current_word = '' - index = 0 - while (i < len(lines)): - splited_line = lines[i].strip().split() - if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): - phn = splited_line[2] - pst = (int(splited_line[0]) / 1000 + 125) / 10000 - pen = (int(splited_line[1]) / 1000 + 125) / 10000 - intervals.append([phn, pst, pen]) - # splited_line[-1]!='sp' - if len(splited_line) == 5: - current_word = str(index) + '_' + splited_line[-1] - word2phns[current_word] = phn - index += 1 - elif len(splited_line) == 4: - word2phns[current_word] += ' ' + phn - i += 1 - return intervals, word2phns - - -def alignment_zh(wav_path: str, text: str): - tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) - - #prepare wav and trs files - try: - os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + - '.wav remix -') - - except Exception: - print('sox error!') - return None - - #prepare clean_transcript file - try: - unk_words = prep_txt_zh( - line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict') - if unk_words: - print('Error! 
Please add the following words to dictionary:') - for unk in unk_words: - print("非法words: ", unk) - except Exception: - print('prep_txt error!') - return None - - #prepare mlf file - try: - with open(tmpbase + '.txt', 'r') as fid: - txt = fid.readline() - prep_mlf(txt, tmpbase) - except Exception: - print('prep_mlf error!') - return None - - #prepare scp - try: - os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase + - '.wav' + ' ' + tmpbase + '.plp') - except Exception: - print('HCopy error!') - return None - - #run alignment - try: - os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + - '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH - + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH - + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase + - '.plp 2>&1 > /dev/null') - - except Exception: - print('HVite error!') - return None - - with open(tmpbase + '.txt', 'r') as fid: - words = fid.readline().strip().split() - words = txt.strip().split() - words.reverse() - - with open(tmpbase + '.aligned', 'r') as fid: - lines = fid.readlines() - - i = 2 - intervals = [] - word2phns = {} - current_word = '' - index = 0 - while (i < len(lines)): - splited_line = lines[i].strip().split() - if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): - phn = splited_line[2] - pst = (int(splited_line[0]) / 1000 + 125) / 10000 - pen = (int(splited_line[1]) / 1000 + 125) / 10000 - intervals.append([phn, pst, pen]) - # splited_line[-1]!='sp' - if len(splited_line) == 5: - current_word = str(index) + '_' + splited_line[-1] - word2phns[current_word] = phn - index += 1 - elif len(splited_line) == 4: - word2phns[current_word] += ' ' + phn - i += 1 - return intervals, word2phns diff --git a/examples/ernie_sat/local/inference.py b/examples/ernie_sat/local/inference.py deleted file mode 100644 index e6a0788fd..000000000 --- a/examples/ernie_sat/local/inference.py +++ /dev/null @@ -1,609 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
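This deleted `inference.py` implements the editing pipeline end to end: align the prompt audio, have the MLM predict mel frames for the edited span, vocode that span, and splice it back into the original waveform. The final splice in `get_wav` below reduces to a short sketch (`old_span_bdy` are mel-frame indices; `hop_length` converts frames to samples):

```python
import numpy as np


def splice_edited_audio(wav_org, alt_wav, old_span_bdy, hop_length=300):
    """Replace the edited span of the original waveform with vocoded audio."""
    # frame indices -> sample indices, mirroring get_wav() below
    start, end = [hop_length * b for b in old_span_bdy]
    return np.concatenate([wav_org[:start], alt_wav, wav_org[end:]])
```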
-import os -import random -from typing import Dict -from typing import List - -import librosa -import numpy as np -import paddle -import soundfile as sf -from align import alignment -from align import alignment_zh -from align import words2phns -from align import words2phns_zh -from paddle import nn -from sedit_arg_parser import parse_args -from utils import eval_durs -from utils import get_voc_out -from utils import is_chinese -from utils import load_num_sequence_text -from utils import read_2col_text - -from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn -from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file - -random.seed(0) -np.random.seed(0) - - -def get_wav(wav_path: str, - source_lang: str='english', - target_lang: str='english', - model_name: str="paddle_checkpoint_en", - old_str: str="", - new_str: str="", - non_autoreg: bool=True): - wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( - source_lang=source_lang, - target_lang=target_lang, - model_name=model_name, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - use_teacher_forcing=non_autoreg) - - masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] - - alt_wav = get_voc_out(masked_feat) - - old_time_bdy = [hop_length * x for x in old_span_bdy] - - wav_replaced = np.concatenate( - [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) - - data_dict = {"origin": wav_org, "output": wav_replaced} - - return data_dict - - -def load_model(model_name: str="paddle_checkpoint_en"): - config_path = './pretrained_model/{}/config.yaml'.format(model_name) - model_path = './pretrained_model/{}/model.pdparams'.format(model_name) - mlm_model, conf = build_model_from_file( - config_file=config_path, model_file=model_path) - return mlm_model, conf - - -def read_data(uid: str, prefix: os.PathLike): - # 获取 uid 对应的文本 - mfa_text = read_2col_text(prefix + '/text')[uid] - # 获取 uid 对应的音频路径 - mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid] - if not os.path.isabs(mfa_wav_path): - mfa_wav_path = prefix + mfa_wav_path - return mfa_text, mfa_wav_path - - -def get_align_data(uid: str, prefix: os.PathLike): - mfa_path = prefix + "mfa_" - mfa_text = read_2col_text(mfa_path + 'text')[uid] - mfa_start = load_num_sequence_text( - mfa_path + 'start', loader_type='text_float')[uid] - mfa_end = load_num_sequence_text( - mfa_path + 'end', loader_type='text_float')[uid] - mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid] - return mfa_text, mfa_start, mfa_end, mfa_wav_path - - -# 获取需要被 mask 的 mel 帧的范围 -def get_masked_mel_bdy(mfa_start: List[float], - mfa_end: List[float], - fs: int, - hop_length: int, - span_to_repl: List[List[int]]): - align_start = np.array(mfa_start) - align_end = np.array(mfa_end) - align_start = np.floor(fs * align_start / hop_length).astype('int') - align_end = np.floor(fs * align_end / hop_length).astype('int') - if span_to_repl[0] >= len(mfa_start): - span_bdy = [align_end[-1], align_end[-1]] - else: - span_bdy = [ - align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1] - ] - return span_bdy, align_start, align_end - - -def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): - dic = {} - keys_to_del = [] - exist_idx = [] - sp_count = 0 - add_sp_count = 0 - for key in word2phns.keys(): - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - exist_idx.append(int(idx)) - else: - keys_to_del.append(key) - - for key in keys_to_del: - del word2phns[key] - - cur_id = 0 - for key in tp_word2phns.keys(): - if 
cur_id in exist_idx: - dic[str(cur_id) + "_sp"] = 'sp' - cur_id += 1 - add_sp_count += 1 - idx, wrd = key.split('_') - dic[str(cur_id) + "_" + wrd] = tp_word2phns[key] - cur_id += 1 - - if add_sp_count + 1 == sp_count: - dic[str(cur_id) + "_sp"] = 'sp' - add_sp_count += 1 - - assert add_sp_count == sp_count, "sp are not added in dic" - return dic - - -def get_max_idx(dic): - return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] - - -def get_phns_and_spans(wav_path: str, - old_str: str="", - new_str: str="", - source_lang: str="english", - target_lang: str="english"): - is_append = (old_str == new_str[:len(old_str)]) - old_phns, mfa_start, mfa_end = [], [], [] - # source - if source_lang == "english": - intervals, word2phns = alignment(wav_path, old_str) - elif source_lang == "chinese": - intervals, word2phns = alignment_zh(wav_path, old_str) - _, tp_word2phns = words2phns_zh(old_str) - - for key, value in tp_word2phns.items(): - idx, wrd = key.split('_') - cur_val = " ".join(value) - tp_word2phns[key] = cur_val - - word2phns = recover_dict(word2phns, tp_word2phns) - else: - assert source_lang == "chinese" or source_lang == "english", \ - "source_lang is wrong..." - - for item in intervals: - old_phns.append(item[0]) - mfa_start.append(float(item[1])) - mfa_end.append(float(item[2])) - # target - if is_append and (source_lang != target_lang): - cross_lingual_clone = True - else: - cross_lingual_clone = False - - if cross_lingual_clone: - str_origin = new_str[:len(old_str)] - str_append = new_str[len(old_str):] - - if target_lang == "chinese": - phns_origin, origin_word2phns = words2phns(str_origin) - phns_append, append_word2phns_tmp = words2phns_zh(str_append) - - elif target_lang == "english": - # 原始句子 - phns_origin, origin_word2phns = words2phns_zh(str_origin) - # clone 句子 - phns_append, append_word2phns_tmp = words2phns(str_append) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "cloning is not support for this language, please check it." - - new_phns = phns_origin + phns_append - - append_word2phns = {} - length = len(origin_word2phns) - for key, value in append_word2phns_tmp.items(): - idx, wrd = key.split('_') - append_word2phns[str(int(idx) + length) + '_' + wrd] = value - new_word2phns = origin_word2phns.copy() - new_word2phns.update(append_word2phns) - - else: - if source_lang == target_lang and target_lang == "english": - new_phns, new_word2phns = words2phns(new_str) - elif source_lang == target_lang and target_lang == "chinese": - new_phns, new_word2phns = words2phns_zh(new_str) - else: - assert source_lang == target_lang, \ - "source language is not same with target language..." 
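The block that follows pins down the edit region by scanning the old and new phone sequences for their longest matching prefix and suffix; whatever remains in the middle becomes `span_to_repl` (phones to cut from the prompt) and `span_to_add` (phones to synthesize). Stripped of the `sp`-token bookkeeping and word-index remapping, the core idea is a sketch like:

```python
def diff_spans(old_phns, new_phns):
    """Locate the differing middle region between two phone sequences."""
    n = min(len(old_phns), len(new_phns))
    left = 0  # longest common prefix
    while left < n and old_phns[left] == new_phns[left]:
        left += 1
    right = 0  # longest common suffix that does not overlap the prefix
    while right < n - left and old_phns[-1 - right] == new_phns[-1 - right]:
        right += 1
    # [start, end) spans into the old and the new sequence
    return [left, len(old_phns) - right], [left, len(new_phns) - right]
```

The real code additionally widens an empty middle span by one phone on each side, so the model always has something to rewrite.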
- - span_to_repl = [0, len(old_phns) - 1] - span_to_add = [0, len(new_phns) - 1] - left_idx = 0 - new_phns_left = [] - sp_count = 0 - # find the left different index - for key in word2phns.keys(): - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - new_phns_left.append('sp') - else: - idx = str(int(idx) - sp_count) - if idx + '_' + wrd in new_word2phns: - left_idx += len(new_word2phns[idx + '_' + wrd]) - new_phns_left.extend(word2phns[key].split()) - else: - span_to_repl[0] = len(new_phns_left) - span_to_add[0] = len(new_phns_left) - break - - # reverse word2phns and new_word2phns - right_idx = 0 - new_phns_right = [] - sp_count = 0 - word2phns_max_idx = get_max_idx(word2phns) - new_word2phns_max_idx = get_max_idx(new_word2phns) - new_phns_mid = [] - if is_append: - new_phns_right = [] - new_phns_mid = new_phns[left_idx:] - span_to_repl[0] = len(new_phns_left) - span_to_add[0] = len(new_phns_left) - span_to_add[1] = len(new_phns_left) + len(new_phns_mid) - span_to_repl[1] = len(old_phns) - len(new_phns_right) - # speech edit - else: - for key in list(word2phns.keys())[::-1]: - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - new_phns_right = ['sp'] + new_phns_right - else: - idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx) - - sp_count)) - if idx + '_' + wrd in new_word2phns: - right_idx -= len(new_word2phns[idx + '_' + wrd]) - new_phns_right = word2phns[key].split() + new_phns_right - else: - span_to_repl[1] = len(old_phns) - len(new_phns_right) - new_phns_mid = new_phns[left_idx:right_idx] - span_to_add[1] = len(new_phns_left) + len(new_phns_mid) - if len(new_phns_mid) == 0: - span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) - span_to_add[0] = max(0, span_to_add[0] - 1) - span_to_repl[0] = max(0, span_to_repl[0] - 1) - span_to_repl[1] = min(span_to_repl[1] + 1, - len(old_phns)) - break - new_phns = new_phns_left + new_phns_mid + new_phns_right - ''' - For that reason cover should not be given. - For that reason cover is impossible to be given. 
- span_to_repl: [17, 23] "should not" - span_to_add: [17, 30] "is impossible to" - ''' - return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add - - -# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 -# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 -def get_dur_adj_factor(orig_dur: List[int], - pred_dur: List[int], - phns: List[str]): - length = 0 - factor_list = [] - for orig, pred, phn in zip(orig_dur, pred_dur, phns): - if pred == 0 or phn == 'sp': - continue - else: - factor_list.append(orig / pred) - factor_list = np.array(factor_list) - factor_list.sort() - if len(factor_list) < 5: - return 1 - length = 2 - avg = np.average(factor_list[length:-length]) - return avg - - -def prep_feats_with_dur(wav_path: str, - source_lang: str="English", - target_lang: str="English", - old_str: str="", - new_str: str="", - mask_reconstruct: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False, - fs: int=24000, - hop_length: int=300): - ''' - Returns: - np.ndarray: new wav, replace the part to be edited in original wav with 0 - List[str]: new phones - List[float]: mfa start of new wav - List[float]: mfa end of new wav - List[int]: masked mel boundary of original wav - List[int]: masked mel boundary of new wav - ''' - wav_org, _ = librosa.load(wav_path, sr=fs) - - mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - source_lang=source_lang, - target_lang=target_lang) - - if start_end_sp: - if new_phns[-1] != 'sp': - new_phns = new_phns + ['sp'] - # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 - if target_lang == "english" or target_lang == "chinese": - old_durs = eval_durs(old_phns, target_lang=source_lang) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "calculate duration_predict is not support for this language..." - - orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] - if '[MASK]' in new_str: - new_phns = old_phns - span_to_add = span_to_repl - d_factor_left = get_dur_adj_factor( - orig_dur=orig_old_durs[:span_to_repl[0]], - pred_dur=old_durs[:span_to_repl[0]], - phns=old_phns[:span_to_repl[0]]) - d_factor_right = get_dur_adj_factor( - orig_dur=orig_old_durs[span_to_repl[1]:], - pred_dur=old_durs[span_to_repl[1]:], - phns=old_phns[span_to_repl[1]:]) - d_factor = (d_factor_left + d_factor_right) / 2 - new_durs_adjusted = [d_factor * i for i in old_durs] - else: - if duration_adjust: - d_factor = get_dur_adj_factor( - orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) - d_factor = d_factor * 1.25 - else: - d_factor = 1 - - if target_lang == "english" or target_lang == "chinese": - new_durs = eval_durs(new_phns, target_lang=target_lang) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "calculate duration_predict is not support for this language..." 
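`get_dur_adj_factor` above collapses the per-phone ratios between forced-alignment durations and predicted durations into one scale factor via a trimmed mean: the ratios are sorted, the two smallest and two largest are dropped (with a fallback to 1 when fewer than five ratios survive), and the remainder is averaged. A worked example with invented frame counts:

```python
import numpy as np

orig_dur = [4, 6, 10, 3, 8, 5]  # frames from forced alignment (invented)
pred_dur = [2, 3, 5, 3, 4, 5]   # frames from the duration predictor (invented)

ratios = sorted(o / p for o, p in zip(orig_dur, pred_dur))
# ratios == [1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
factor = np.average(ratios[2:-2]) if len(ratios) >= 5 else 1.0
print(factor)  # 2.0
```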
- - new_durs_adjusted = [d_factor * i for i in new_durs] - - new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) - old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) - dur_offset = new_span_dur_sum - old_span_dur_sum - new_mfa_start = mfa_start[:span_to_repl[0]] - new_mfa_end = mfa_end[:span_to_repl[0]] - for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: - if len(new_mfa_end) == 0: - new_mfa_start.append(0) - new_mfa_end.append(i) - else: - new_mfa_start.append(new_mfa_end[-1]) - new_mfa_end.append(new_mfa_end[-1] + i) - new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] - new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] - - # 3. get new wav - # 在原始句子后拼接 - if span_to_repl[0] >= len(mfa_start): - left_idx = len(wav_org) - right_idx = left_idx - # 在原始句子中间替换 - else: - left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) - right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) - blank_wav = np.zeros( - (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype) - # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 - new_wav = np.concatenate( - [wav_org[:left_idx], blank_wav, wav_org[right_idx:]]) - - # 4. get old and new mel span to be mask - # [92, 92] - - old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy( - mfa_start=mfa_start, - mfa_end=mfa_end, - fs=fs, - hop_length=hop_length, - span_to_repl=span_to_repl) - # [92, 174] - # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别 - new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy( - mfa_start=new_mfa_start, - mfa_end=new_mfa_end, - fs=fs, - hop_length=hop_length, - span_to_repl=span_to_add) - - # old_span_bdy, new_span_bdy 是帧级别的范围 - return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy - - -def prep_feats(wav_path: str, - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - duration_adjust: bool=True, - start_end_sp: bool=False, - mask_reconstruct: bool=False, - fs: int=24000, - hop_length: int=300, - token_list: List[str]=[]): - wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( - source_lang=source_lang, - target_lang=target_lang, - old_str=old_str, - new_str=new_str, - wav_path=wav_path, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - mask_reconstruct=mask_reconstruct, - fs=fs, - hop_length=hop_length) - - token_to_id = {item: i for i, item in enumerate(token_list)} - text = np.array( - list(map(lambda x: token_to_id.get(x, token_to_id['']), phns))) - span_bdy = np.array(new_span_bdy) - - batch = [('1', { - "speech": wav, - "align_start": mfa_start, - "align_end": mfa_end, - "text": text, - "span_bdy": span_bdy - })] - - return batch, old_span_bdy, new_span_bdy - - -def decode_with_model(mlm_model: nn.Layer, - collate_fn, - wav_path: str, - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - use_teacher_forcing: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False, - fs: int=24000, - hop_length: int=300, - token_list: List[str]=[]): - batch, old_span_bdy, new_span_bdy = prep_feats( - source_lang=source_lang, - target_lang=target_lang, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - fs=fs, - hop_length=hop_length, - token_list=token_list) - - feats = collate_fn(batch)[1] - - if 'text_masked_pos' in feats.keys(): - feats.pop('text_masked_pos') - - output = 
mlm_model.inference( - text=feats['text'], - speech=feats['speech'], - masked_pos=feats['masked_pos'], - speech_mask=feats['speech_mask'], - text_mask=feats['text_mask'], - speech_seg_pos=feats['speech_seg_pos'], - text_seg_pos=feats['text_seg_pos'], - span_bdy=new_span_bdy, - use_teacher_forcing=use_teacher_forcing) - - # 拼接音频 - output_feat = paddle.concat(x=output, axis=0) - wav_org, _ = librosa.load(wav_path, sr=fs) - return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length - - -def get_mlm_output(wav_path: str, - model_name: str="paddle_checkpoint_en", - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - use_teacher_forcing: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False): - mlm_model, train_conf = load_model(model_name) - mlm_model.eval() - - collate_fn = build_mlm_collate_fn( - sr=train_conf.feats_extract_conf['fs'], - n_fft=train_conf.feats_extract_conf['n_fft'], - hop_length=train_conf.feats_extract_conf['hop_length'], - win_length=train_conf.feats_extract_conf['win_length'], - n_mels=train_conf.feats_extract_conf['n_mels'], - fmin=train_conf.feats_extract_conf['fmin'], - fmax=train_conf.feats_extract_conf['fmax'], - mlm_prob=train_conf['mlm_prob'], - mean_phn_span=train_conf['mean_phn_span'], - seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm') - - return decode_with_model( - source_lang=source_lang, - target_lang=target_lang, - mlm_model=mlm_model, - collate_fn=collate_fn, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - use_teacher_forcing=use_teacher_forcing, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - fs=train_conf.feats_extract_conf['fs'], - hop_length=train_conf.feats_extract_conf['hop_length'], - token_list=train_conf.token_list) - - -def evaluate(uid: str, - source_lang: str="english", - target_lang: str="english", - prefix: os.PathLike="./prompt/dev/", - model_name: str="paddle_checkpoint_en", - new_str: str="", - prompt_decoding: bool=False, - task_name: str=None): - - # get origin text and path of origin wav - old_str, wav_path = read_data(uid=uid, prefix=prefix) - - if task_name == 'edit': - new_str = new_str - elif task_name == 'synthesize': - new_str = old_str + new_str - else: - new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) - - print('new_str is ', new_str) - - results_dict = get_wav( - source_lang=source_lang, - target_lang=target_lang, - model_name=model_name, - wav_path=wav_path, - old_str=old_str, - new_str=new_str) - return results_dict - - -if __name__ == "__main__": - # parse config and args - args = parse_args() - - data_dict = evaluate( - uid=args.uid, - source_lang=args.source_lang, - target_lang=args.target_lang, - prefix=args.prefix, - model_name=args.model_name, - new_str=args.new_str, - task_name=args.task_name) - sf.write(args.output_name, data_dict['output'], samplerate=24000) - print("finished...") diff --git a/examples/ernie_sat/local/inference_new.py b/examples/ernie_sat/local/inference_new.py deleted file mode 100644 index 525967eb1..000000000 --- a/examples/ernie_sat/local/inference_new.py +++ /dev/null @@ -1,622 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import random -from typing import Dict -from typing import List - -import librosa -import numpy as np -import paddle -import soundfile as sf -import yaml -from align import alignment -from align import alignment_zh -from align import words2phns -from align import words2phns_zh -from paddle import nn -from sedit_arg_parser import parse_args -from utils import eval_durs -from utils import get_voc_out -from utils import is_chinese -from utils import load_num_sequence_text -from utils import read_2col_text -from yacs.config import CfgNode - -from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn -from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT - -random.seed(0) -np.random.seed(0) - - -def get_wav(wav_path: str, - source_lang: str='english', - target_lang: str='english', - model_name: str="paddle_checkpoint_en", - old_str: str="", - new_str: str="", - non_autoreg: bool=True): - wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( - source_lang=source_lang, - target_lang=target_lang, - model_name=model_name, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - use_teacher_forcing=non_autoreg) - - masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] - - alt_wav = get_voc_out(masked_feat) - - old_time_bdy = [hop_length * x for x in old_span_bdy] - - wav_replaced = np.concatenate( - [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) - - data_dict = {"origin": wav_org, "output": wav_replaced} - - return data_dict - - -def load_model(model_name: str="paddle_checkpoint_en"): - config_path = './pretrained_model/{}/default.yaml'.format(model_name) - model_path = './pretrained_model/{}/model.pdparams'.format(model_name) - with open(config_path) as f: - conf = CfgNode(yaml.safe_load(f)) - token_list = list(conf.token_list) - vocab_size = len(token_list) - odim = conf.n_mels - mlm_model = ErnieSAT(idim=vocab_size, odim=odim, **conf["model"]) - state_dict = paddle.load(model_path) - new_state_dict = {} - for key, value in state_dict.items(): - new_key = "model." 
+ key - new_state_dict[new_key] = value - mlm_model.set_state_dict(new_state_dict) - mlm_model.eval() - - return mlm_model, conf - - -def read_data(uid: str, prefix: os.PathLike): - # 获取 uid 对应的文本 - mfa_text = read_2col_text(prefix + '/text')[uid] - # 获取 uid 对应的音频路径 - mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid] - if not os.path.isabs(mfa_wav_path): - mfa_wav_path = prefix + mfa_wav_path - return mfa_text, mfa_wav_path - - -def get_align_data(uid: str, prefix: os.PathLike): - mfa_path = prefix + "mfa_" - mfa_text = read_2col_text(mfa_path + 'text')[uid] - mfa_start = load_num_sequence_text( - mfa_path + 'start', loader_type='text_float')[uid] - mfa_end = load_num_sequence_text( - mfa_path + 'end', loader_type='text_float')[uid] - mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid] - return mfa_text, mfa_start, mfa_end, mfa_wav_path - - -# 获取需要被 mask 的 mel 帧的范围 -def get_masked_mel_bdy(mfa_start: List[float], - mfa_end: List[float], - fs: int, - hop_length: int, - span_to_repl: List[List[int]]): - align_start = np.array(mfa_start) - align_end = np.array(mfa_end) - align_start = np.floor(fs * align_start / hop_length).astype('int') - align_end = np.floor(fs * align_end / hop_length).astype('int') - if span_to_repl[0] >= len(mfa_start): - span_bdy = [align_end[-1], align_end[-1]] - else: - span_bdy = [ - align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1] - ] - return span_bdy, align_start, align_end - - -def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): - dic = {} - keys_to_del = [] - exist_idx = [] - sp_count = 0 - add_sp_count = 0 - for key in word2phns.keys(): - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - exist_idx.append(int(idx)) - else: - keys_to_del.append(key) - - for key in keys_to_del: - del word2phns[key] - - cur_id = 0 - for key in tp_word2phns.keys(): - if cur_id in exist_idx: - dic[str(cur_id) + "_sp"] = 'sp' - cur_id += 1 - add_sp_count += 1 - idx, wrd = key.split('_') - dic[str(cur_id) + "_" + wrd] = tp_word2phns[key] - cur_id += 1 - - if add_sp_count + 1 == sp_count: - dic[str(cur_id) + "_sp"] = 'sp' - add_sp_count += 1 - - assert add_sp_count == sp_count, "sp are not added in dic" - return dic - - -def get_max_idx(dic): - return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] - - -def get_phns_and_spans(wav_path: str, - old_str: str="", - new_str: str="", - source_lang: str="english", - target_lang: str="english"): - is_append = (old_str == new_str[:len(old_str)]) - old_phns, mfa_start, mfa_end = [], [], [] - # source - if source_lang == "english": - intervals, word2phns = alignment(wav_path, old_str) - elif source_lang == "chinese": - intervals, word2phns = alignment_zh(wav_path, old_str) - _, tp_word2phns = words2phns_zh(old_str) - - for key, value in tp_word2phns.items(): - idx, wrd = key.split('_') - cur_val = " ".join(value) - tp_word2phns[key] = cur_val - - word2phns = recover_dict(word2phns, tp_word2phns) - else: - assert source_lang == "chinese" or source_lang == "english", \ - "source_lang is wrong..." 
- - for item in intervals: - old_phns.append(item[0]) - mfa_start.append(float(item[1])) - mfa_end.append(float(item[2])) - # target - if is_append and (source_lang != target_lang): - cross_lingual_clone = True - else: - cross_lingual_clone = False - - if cross_lingual_clone: - str_origin = new_str[:len(old_str)] - str_append = new_str[len(old_str):] - - if target_lang == "chinese": - phns_origin, origin_word2phns = words2phns(str_origin) - phns_append, append_word2phns_tmp = words2phns_zh(str_append) - - elif target_lang == "english": - # 原始句子 - phns_origin, origin_word2phns = words2phns_zh(str_origin) - # clone 句子 - phns_append, append_word2phns_tmp = words2phns(str_append) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "cloning is not support for this language, please check it." - - new_phns = phns_origin + phns_append - - append_word2phns = {} - length = len(origin_word2phns) - for key, value in append_word2phns_tmp.items(): - idx, wrd = key.split('_') - append_word2phns[str(int(idx) + length) + '_' + wrd] = value - new_word2phns = origin_word2phns.copy() - new_word2phns.update(append_word2phns) - - else: - if source_lang == target_lang and target_lang == "english": - new_phns, new_word2phns = words2phns(new_str) - elif source_lang == target_lang and target_lang == "chinese": - new_phns, new_word2phns = words2phns_zh(new_str) - else: - assert source_lang == target_lang, \ - "source language is not same with target language..." - - span_to_repl = [0, len(old_phns) - 1] - span_to_add = [0, len(new_phns) - 1] - left_idx = 0 - new_phns_left = [] - sp_count = 0 - # find the left different index - for key in word2phns.keys(): - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - new_phns_left.append('sp') - else: - idx = str(int(idx) - sp_count) - if idx + '_' + wrd in new_word2phns: - left_idx += len(new_word2phns[idx + '_' + wrd]) - new_phns_left.extend(word2phns[key].split()) - else: - span_to_repl[0] = len(new_phns_left) - span_to_add[0] = len(new_phns_left) - break - - # reverse word2phns and new_word2phns - right_idx = 0 - new_phns_right = [] - sp_count = 0 - word2phns_max_idx = get_max_idx(word2phns) - new_word2phns_max_idx = get_max_idx(new_word2phns) - new_phns_mid = [] - if is_append: - new_phns_right = [] - new_phns_mid = new_phns[left_idx:] - span_to_repl[0] = len(new_phns_left) - span_to_add[0] = len(new_phns_left) - span_to_add[1] = len(new_phns_left) + len(new_phns_mid) - span_to_repl[1] = len(old_phns) - len(new_phns_right) - # speech edit - else: - for key in list(word2phns.keys())[::-1]: - idx, wrd = key.split('_') - if wrd == 'sp': - sp_count += 1 - new_phns_right = ['sp'] + new_phns_right - else: - idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx) - - sp_count)) - if idx + '_' + wrd in new_word2phns: - right_idx -= len(new_word2phns[idx + '_' + wrd]) - new_phns_right = word2phns[key].split() + new_phns_right - else: - span_to_repl[1] = len(old_phns) - len(new_phns_right) - new_phns_mid = new_phns[left_idx:right_idx] - span_to_add[1] = len(new_phns_left) + len(new_phns_mid) - if len(new_phns_mid) == 0: - span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) - span_to_add[0] = max(0, span_to_add[0] - 1) - span_to_repl[0] = max(0, span_to_repl[0] - 1) - span_to_repl[1] = min(span_to_repl[1] + 1, - len(old_phns)) - break - new_phns = new_phns_left + new_phns_mid + new_phns_right - ''' - For that reason cover should not be given. - For that reason cover is impossible to be given. 
- span_to_repl: [17, 23] "should not" - span_to_add: [17, 30] "is impossible to" - ''' - return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add - - -# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 -# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 -def get_dur_adj_factor(orig_dur: List[int], - pred_dur: List[int], - phns: List[str]): - length = 0 - factor_list = [] - for orig, pred, phn in zip(orig_dur, pred_dur, phns): - if pred == 0 or phn == 'sp': - continue - else: - factor_list.append(orig / pred) - factor_list = np.array(factor_list) - factor_list.sort() - if len(factor_list) < 5: - return 1 - length = 2 - avg = np.average(factor_list[length:-length]) - return avg - - -def prep_feats_with_dur(wav_path: str, - source_lang: str="English", - target_lang: str="English", - old_str: str="", - new_str: str="", - mask_reconstruct: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False, - fs: int=24000, - hop_length: int=300): - ''' - Returns: - np.ndarray: new wav, replace the part to be edited in original wav with 0 - List[str]: new phones - List[float]: mfa start of new wav - List[float]: mfa end of new wav - List[int]: masked mel boundary of original wav - List[int]: masked mel boundary of new wav - ''' - wav_org, _ = librosa.load(wav_path, sr=fs) - - mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - source_lang=source_lang, - target_lang=target_lang) - - if start_end_sp: - if new_phns[-1] != 'sp': - new_phns = new_phns + ['sp'] - # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 - if target_lang == "english" or target_lang == "chinese": - old_durs = eval_durs(old_phns, target_lang=source_lang) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "calculate duration_predict is not support for this language..." - - orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] - if '[MASK]' in new_str: - new_phns = old_phns - span_to_add = span_to_repl - d_factor_left = get_dur_adj_factor( - orig_dur=orig_old_durs[:span_to_repl[0]], - pred_dur=old_durs[:span_to_repl[0]], - phns=old_phns[:span_to_repl[0]]) - d_factor_right = get_dur_adj_factor( - orig_dur=orig_old_durs[span_to_repl[1]:], - pred_dur=old_durs[span_to_repl[1]:], - phns=old_phns[span_to_repl[1]:]) - d_factor = (d_factor_left + d_factor_right) / 2 - new_durs_adjusted = [d_factor * i for i in old_durs] - else: - if duration_adjust: - d_factor = get_dur_adj_factor( - orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) - d_factor = d_factor * 1.25 - else: - d_factor = 1 - - if target_lang == "english" or target_lang == "chinese": - new_durs = eval_durs(new_phns, target_lang=target_lang) - else: - assert target_lang == "chinese" or target_lang == "english", \ - "calculate duration_predict is not support for this language..." 
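Later in this function the second-level span boundaries are converted to mel-frame indices by `get_masked_mel_bdy` (defined above) via `floor(fs * t / hop_length)`. With the defaults used throughout these files (`fs=24000`, `hop_length=300`), one second covers 80 frames; a quick numeric check with invented timestamps:

```python
import numpy as np

fs, hop_length = 24000, 300   # 24000 / 300 = 80 mel frames per second
start_s, end_s = 1.25, 1.50   # invented MFA timestamps, in seconds
start_frame = int(np.floor(fs * start_s / hop_length))
end_frame = int(np.floor(fs * end_s / hop_length))
print(start_frame, end_frame)  # 100 120
```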
- - new_durs_adjusted = [d_factor * i for i in new_durs] - - new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) - old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) - dur_offset = new_span_dur_sum - old_span_dur_sum - new_mfa_start = mfa_start[:span_to_repl[0]] - new_mfa_end = mfa_end[:span_to_repl[0]] - for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: - if len(new_mfa_end) == 0: - new_mfa_start.append(0) - new_mfa_end.append(i) - else: - new_mfa_start.append(new_mfa_end[-1]) - new_mfa_end.append(new_mfa_end[-1] + i) - new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] - new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] - - # 3. get new wav - # 在原始句子后拼接 - if span_to_repl[0] >= len(mfa_start): - left_idx = len(wav_org) - right_idx = left_idx - # 在原始句子中间替换 - else: - left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) - right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) - blank_wav = np.zeros( - (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype) - # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 - new_wav = np.concatenate( - [wav_org[:left_idx], blank_wav, wav_org[right_idx:]]) - - # 4. get old and new mel span to be mask - # [92, 92] - - old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy( - mfa_start=mfa_start, - mfa_end=mfa_end, - fs=fs, - hop_length=hop_length, - span_to_repl=span_to_repl) - # [92, 174] - # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别 - new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy( - mfa_start=new_mfa_start, - mfa_end=new_mfa_end, - fs=fs, - hop_length=hop_length, - span_to_repl=span_to_add) - - # old_span_bdy, new_span_bdy 是帧级别的范围 - return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy - - -def prep_feats(wav_path: str, - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - duration_adjust: bool=True, - start_end_sp: bool=False, - mask_reconstruct: bool=False, - fs: int=24000, - hop_length: int=300, - token_list: List[str]=[]): - wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( - source_lang=source_lang, - target_lang=target_lang, - old_str=old_str, - new_str=new_str, - wav_path=wav_path, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - mask_reconstruct=mask_reconstruct, - fs=fs, - hop_length=hop_length) - - token_to_id = {item: i for i, item in enumerate(token_list)} - text = np.array( - list(map(lambda x: token_to_id.get(x, token_to_id['']), phns))) - span_bdy = np.array(new_span_bdy) - - batch = [('1', { - "speech": wav, - "align_start": mfa_start, - "align_end": mfa_end, - "text": text, - "span_bdy": span_bdy - })] - - return batch, old_span_bdy, new_span_bdy - - -def decode_with_model(mlm_model: nn.Layer, - collate_fn, - wav_path: str, - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - use_teacher_forcing: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False, - fs: int=24000, - hop_length: int=300, - token_list: List[str]=[]): - batch, old_span_bdy, new_span_bdy = prep_feats( - source_lang=source_lang, - target_lang=target_lang, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - fs=fs, - hop_length=hop_length, - token_list=token_list) - - feats = collate_fn(batch)[1] - - if 'text_masked_pos' in feats.keys(): - feats.pop('text_masked_pos') - - output = 
mlm_model.inference( - text=feats['text'], - speech=feats['speech'], - masked_pos=feats['masked_pos'], - speech_mask=feats['speech_mask'], - text_mask=feats['text_mask'], - speech_seg_pos=feats['speech_seg_pos'], - text_seg_pos=feats['text_seg_pos'], - span_bdy=new_span_bdy, - use_teacher_forcing=use_teacher_forcing) - - # 拼接音频 - output_feat = paddle.concat(x=output, axis=0) - wav_org, _ = librosa.load(wav_path, sr=fs) - return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length - - -def get_mlm_output(wav_path: str, - model_name: str="paddle_checkpoint_en", - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - use_teacher_forcing: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False): - mlm_model, train_conf = load_model(model_name) - - collate_fn = build_mlm_collate_fn( - sr=train_conf.fs, - n_fft=train_conf.n_fft, - hop_length=train_conf.n_shift, - win_length=train_conf.win_length, - n_mels=train_conf.n_mels, - fmin=train_conf.fmin, - fmax=train_conf.fmax, - mlm_prob=train_conf.mlm_prob, - mean_phn_span=train_conf.mean_phn_span, - seg_emb=train_conf.model['enc_input_layer'] == 'sega_mlm') - - return decode_with_model( - source_lang=source_lang, - target_lang=target_lang, - mlm_model=mlm_model, - collate_fn=collate_fn, - wav_path=wav_path, - old_str=old_str, - new_str=new_str, - use_teacher_forcing=use_teacher_forcing, - duration_adjust=duration_adjust, - start_end_sp=start_end_sp, - fs=train_conf.fs, - hop_length=train_conf.n_shift, - token_list=train_conf.token_list) - - -def evaluate(uid: str, - source_lang: str="english", - target_lang: str="english", - prefix: os.PathLike="./prompt/dev/", - model_name: str="paddle_checkpoint_en", - new_str: str="", - prompt_decoding: bool=False, - task_name: str=None): - - # get origin text and path of origin wav - old_str, wav_path = read_data(uid=uid, prefix=prefix) - - if task_name == 'edit': - new_str = new_str - elif task_name == 'synthesize': - new_str = old_str + new_str - else: - new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) - - print('new_str is ', new_str) - - results_dict = get_wav( - source_lang=source_lang, - target_lang=target_lang, - model_name=model_name, - wav_path=wav_path, - old_str=old_str, - new_str=new_str) - return results_dict - - -if __name__ == "__main__": - # parse config and args - args = parse_args() - - data_dict = evaluate( - uid=args.uid, - source_lang=args.source_lang, - target_lang=args.target_lang, - prefix=args.prefix, - model_name=args.model_name, - new_str=args.new_str, - task_name=args.task_name) - sf.write(args.output_name, data_dict['output'], samplerate=24000) - print("finished...") diff --git a/examples/ernie_sat/local/sedit_arg_parser.py b/examples/ernie_sat/local/sedit_arg_parser.py deleted file mode 100644 index ad7e57191..000000000 --- a/examples/ernie_sat/local/sedit_arg_parser.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/examples/ernie_sat/local/utils.py b/examples/ernie_sat/local/utils.py
deleted file mode 100644
index f2dce504a..000000000
--- a/examples/ernie_sat/local/utils.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Union
-
-import numpy as np
-import paddle
-import yaml
-from sedit_arg_parser import parse_args
-from yacs.config import CfgNode
-
-from paddlespeech.t2s.exps.syn_utils import get_am_inference
-from paddlespeech.t2s.exps.syn_utils import get_voc_inference
-
-
-def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
-    """Read a text file with two columns as a dict object.
-
-    Examples:
-        wav.scp:
-            key1 /some/path/a.wav
-            key2 /some/path/b.wav
-
-        >>> read_2col_text('wav.scp')
-        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
-
-    """
-
-    data = {}
-    with Path(path).open("r", encoding="utf-8") as f:
-        for linenum, line in enumerate(f, 1):
-            sps = line.rstrip().split(maxsplit=1)
-            if len(sps) == 1:
-                k, v = sps[0], ""
-            else:
-                k, v = sps
-            if k in data:
-                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
-            data[k] = v
-    return data
-
-
-def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
-                           ) -> Dict[str, List[Union[float, int]]]:
-    """Read a text file indicating sequences of numbers
-
-    Examples:
-        key1 1 2 3
-        key2 34 5 6
-
-        >>> d = load_num_sequence_text('text')
-        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
-    """
-    if loader_type == "text_int":
-        delimiter = " "
-        dtype = int
-    elif loader_type == "text_float":
-        delimiter = " "
-        dtype = float
-    elif loader_type == "csv_int":
-        delimiter = ","
-        dtype = int
-    elif loader_type == "csv_float":
-        delimiter = ","
-        dtype = float
-    else:
-        raise ValueError(f"Not supported loader_type={loader_type}")
-
-    # path looks like:
-    #   utta 1,0
-    #   uttb 3,4,5
-    # -> return {'utta': np.ndarray([1, 0]),
-    #            'uttb': np.ndarray([3, 4, 5])}
-    d = read_2col_text(path)
-    # Using for-loop instead of dict-comprehension for debuggability
-    retval = {}
-    for k, v in d.items():
-        try:
-            retval[k] = [dtype(i) for i in v.split(delimiter)]
-        except TypeError:
-            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
-            raise
-    return retval
-
-
-def is_chinese(ch):
-    if u'\u4e00' <= ch <= u'\u9fff':
-        return True
-    else:
-        return False
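To make the preprocessing note in the run scripts further below concrete: `is_chinese` is what `evaluate()` in `local/inference.py` uses to keep only the Chinese characters of `new_str` for the cross-lingual task. A quick illustration (the sample string is hypothetical):

```python
def is_chinese(ch):
    return u'\u4e00' <= ch <= u'\u9fff'

# Mirrors evaluate(): non-Chinese characters are dropped and the remaining
# characters are joined with spaces before synthesis.
new_str = 'Weather: 今天天气很好.'
print(' '.join(ch for ch in new_str if is_chinese(ch)))  # 今 天 天 气 很 好
```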
-
-
-def get_voc_out(mel):
-    # vocoder
-    args = parse_args()
-    with open(args.voc_config) as f:
-        voc_config = CfgNode(yaml.safe_load(f))
-    voc_inference = get_voc_inference(
-        voc=args.voc,
-        voc_config=voc_config,
-        voc_ckpt=args.voc_ckpt,
-        voc_stat=args.voc_stat)
-
-    with paddle.no_grad():
-        wav = voc_inference(mel)
-    return np.squeeze(wav)
-
-
-def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
-    args = parse_args()
-
-    if target_lang == 'english':
-        args.am = "fastspeech2_ljspeech"
-        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
-        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
-        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
-        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
-
-    elif target_lang == 'chinese':
-        args.am = "fastspeech2_csmsc"
-        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
-        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
-        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
-        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
-
-    if args.ngpu == 0:
-        paddle.set_device("cpu")
-    elif args.ngpu > 0:
-        paddle.set_device("gpu")
-    else:
-        print("ngpu should be >= 0!")
-
-    # Init body.
-    with open(args.am_config) as f:
-        am_config = CfgNode(yaml.safe_load(f))
-
-    am_inference, am = get_am_inference(
-        am=args.am,
-        am_config=am_config,
-        am_ckpt=args.am_ckpt,
-        am_stat=args.am_stat,
-        phones_dict=args.phones_dict,
-        tones_dict=args.tones_dict,
-        speaker_dict=args.speaker_dict,
-        return_am=True)
-
-    vocab_phones = {}
-    with open(args.phones_dict, "r") as f:
-        phn_id = [line.strip().split() for line in f.readlines()]
-    for phn, id_ in phn_id:
-        vocab_phones[phn] = int(id_)
-    vocab_size = len(vocab_phones)
-    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
-
-    phone_ids = [vocab_phones[item] for item in phonemes]
-    phone_ids.append(vocab_size - 1)
-    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
-    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
-    pre_d_outs = d_outs
-    phu_durs_new = pre_d_outs * hop_length / fs
-    phu_durs_new = phu_durs_new.tolist()[:-1]
-    return phu_durs_new
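The tail of `eval_durs` above converts FastSpeech2's per-phone duration predictions from mel frames into seconds. A short worked example under the defaults used here (`fs=24000`, `hop_length=300`; the prediction values are hypothetical):

```python
# d_outs are per-phone durations in mel frames; one frame spans
# hop_length / fs = 300 / 24000 = 0.0125 s.
d_outs = [10, 24, 8]           # hypothetical predictions for three phones
durs_s = [d * 300 / 24000 for d in d_outs]
print(durs_s)                  # [0.125, 0.3, 0.1]
```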
diff --git a/examples/ernie_sat/path.sh b/examples/ernie_sat/path.sh
deleted file mode 100755
index d46d2f612..000000000
--- a/examples/ernie_sat/path.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-export MAIN_ROOT=`realpath ${PWD}/../../`
-
-export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
-export LC_ALL=C
-
-export PYTHONDONTWRITEBYTECODE=1
-# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
-
-MODEL=ernie_sat
-export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
diff --git a/examples/ernie_sat/prompt/dev/text b/examples/ernie_sat/prompt/dev/text
deleted file mode 100644
index f79cdcb42..000000000
--- a/examples/ernie_sat/prompt/dev/text
+++ /dev/null
@@ -1,3 +0,0 @@
-p243_new For that reason cover should not be given.
-Prompt_003_new This was not the show for me.
-p299_096 We are trying to establish a date.
diff --git a/examples/ernie_sat/prompt/dev/wav.scp b/examples/ernie_sat/prompt/dev/wav.scp
deleted file mode 100644
index eb0e8e48d..000000000
--- a/examples/ernie_sat/prompt/dev/wav.scp
+++ /dev/null
@@ -1,3 +0,0 @@
-p243_new ../../prompt_wav/p243_313.wav
-Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
-p299_096 ../../prompt_wav/p299_096.wav
diff --git a/examples/ernie_sat/run_clone_en_to_zh.sh b/examples/ernie_sat/run_clone_en_to_zh.sh
deleted file mode 100755
index 68b1c7544..000000000
--- a/examples/ernie_sat/run_clone_en_to_zh.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# en --> zh speech synthesis
-# Use Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize: '今天天气很好'
-# NOTE: the input new_str must consist of Chinese characters; otherwise preprocessing keeps only its Chinese characters, i.e. the synthesized Chinese speech corresponds to the preprocessed text.
-
-python local/inference.py \
-    --task_name=cross-lingual_clone \
-    --model_name=paddle_checkpoint_dual_mask_enzh \
-    --uid=Prompt_003_new \
-    --new_str='今天天气很好.' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=chinese \
-    --output_name=pred_clone.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_csmsc \
-    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
-    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
-    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_clone_en_to_zh_new.sh b/examples/ernie_sat/run_clone_en_to_zh_new.sh
deleted file mode 100755
index 12fdf23f1..000000000
--- a/examples/ernie_sat/run_clone_en_to_zh_new.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# en --> zh speech synthesis
-# Use Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize: '今天天气很好'
-# NOTE: the input new_str must consist of Chinese characters; otherwise preprocessing keeps only its Chinese characters, i.e. the synthesized Chinese speech corresponds to the preprocessed text.
-
-python local/inference_new.py \
-    --task_name=cross-lingual_clone \
-    --model_name=paddle_checkpoint_dual_mask_enzh \
-    --uid=Prompt_003_new \
-    --new_str='今天天气很好.' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=chinese \
-    --output_name=pred_clone.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_csmsc \
-    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
-    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
-    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_gen_en.sh b/examples/ernie_sat/run_gen_en.sh
deleted file mode 100755
index a0641bc7f..000000000
--- a/examples/ernie_sat/run_gen_en.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# English-only speech synthesis
-# Example: use the speech of p299_096 ("We are trying to establish a date.") as the prompt to synthesize: 'I enjoy my life.'
-
-python local/inference.py \
-    --task_name=synthesize \
-    --model_name=paddle_checkpoint_en \
-    --uid=p299_096 \
-    --new_str='I enjoy my life, do you?' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=english \
-    --output_name=pred_gen.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_ljspeech \
-    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
-    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
-    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_gen_en_new.sh b/examples/ernie_sat/run_gen_en_new.sh
deleted file mode 100755
index d76b00430..000000000
--- a/examples/ernie_sat/run_gen_en_new.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# English-only speech synthesis
-# Example: use the speech of p299_096 ("We are trying to establish a date.") as the prompt to synthesize: 'I enjoy my life.'
-
-python local/inference_new.py \
-    --task_name=synthesize \
-    --model_name=paddle_checkpoint_en \
-    --uid=p299_096 \
-    --new_str='I enjoy my life, do you?' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=english \
-    --output_name=pred_gen.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_ljspeech \
-    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
-    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
-    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_sedit_en.sh b/examples/ernie_sat/run_sedit_en.sh
deleted file mode 100755
index eec7d6402..000000000
--- a/examples/ernie_sat/run_sedit_en.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# English-only speech editing
-# Example: edit the original speech of p243_new ("For that reason cover should not be given.") into the speech of 'for that reason cover is impossible to be given.'
-# NOTE: the speech editing task currently supports replacing or inserting text at only 1 position in a sentence
-
-python local/inference.py \
-    --task_name=edit \
-    --model_name=paddle_checkpoint_en \
-    --uid=p243_new \
-    --new_str='for that reason cover is impossible to be given.' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=english \
-    --output_name=pred_edit.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_ljspeech \
-    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
-    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
-    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_sedit_en_new.sh b/examples/ernie_sat/run_sedit_en_new.sh
deleted file mode 100755
index 0952d280c..000000000
--- a/examples/ernie_sat/run_sedit_en_new.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e
-source path.sh
-
-# English-only speech editing
-# Example: edit the original speech of p243_new ("For that reason cover should not be given.") into the speech of 'for that reason cover is impossible to be given.'
-# NOTE: the speech editing task currently supports replacing or inserting text at only 1 position in a sentence
-
-python local/inference_new.py \
-    --task_name=edit \
-    --model_name=paddle_checkpoint_en \
-    --uid=p243_new \
-    --new_str='for that reason cover is impossible to be given.' \
-    --prefix='./prompt/dev/' \
-    --source_lang=english \
-    --target_lang=english \
-    --output_name=pred_edit.wav \
-    --voc=pwgan_aishell3 \
-    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --am=fastspeech2_ljspeech \
-    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
-    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
-    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
-    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/test_run.sh b/examples/ernie_sat/test_run.sh
deleted file mode 100755
index 75b6a5691..000000000
--- a/examples/ernie_sat/test_run.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-rm -rf *.wav
-./run_sedit_en.sh          # speech editing task (English)
-./run_gen_en.sh            # personalized speech synthesis task (English)
-./run_clone_en_to_zh.sh    # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
diff --git a/examples/ernie_sat/test_run_new.sh b/examples/ernie_sat/test_run_new.sh
deleted file mode 100755
index bf8a4e02d..000000000
--- a/examples/ernie_sat/test_run_new.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-rm -rf *.wav
-./run_sedit_en_new.sh          # speech editing task (English)
-./run_gen_en_new.sh            # personalized speech synthesis task (English)
-./run_clone_en_to_zh_new.sh    # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
diff --git a/examples/ernie_sat/tools/.gitkeep b/examples/ernie_sat/tools/.gitkeep
deleted file mode 100644
index e69de29bb..000000000
diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml
index 8650154e9..55f241ec7 100644
--- a/paddlespeech/server/conf/application.yaml
+++ b/paddlespeech/server/conf/application.yaml
@@ -25,6 +25,7 @@ asr_python:
     cfg_path:  # [optional]
     ckpt_path:  # [optional]
     decode_method: 'attention_rescoring'
+    num_decoding_left_chunks: -1
     force_yes: True
     device:  # set 'gpu:id' or 'cpu'
@@ -38,6 +39,7 @@ asr_inference:
     lang: 'zh'
     sample_rate: 16000
     cfg_path:
+    num_decoding_left_chunks: -1
     decode_method:
     force_yes: True
diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py
index 02c40fd12..9ce05d97a 100644
--- a/paddlespeech/server/engine/asr/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/python/asr_engine.py
@@ -66,11 +66,12 @@ class ASREngine(BaseEngine):
                 )
                 logger.error(e)
                 return False
-
+
         self.executor._init_from_path(
-            self.config.model, self.config.lang, self.config.sample_rate,
-            self.config.cfg_path, self.config.decode_method,
-            self.config.ckpt_path)
+            model_type=self.config.model, lang=self.config.lang,
+            sample_rate=self.config.sample_rate, cfg_path=self.config.cfg_path,
+            decode_method=self.config.decode_method,
+            ckpt_path=self.config.ckpt_path)
+
         logger.info("Initialize ASR server engine successfully on device: %s." %
                     (self.device))
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 2cb7a11a2..c4c9e5d73 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -11,19 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Collection -from typing import Dict -from typing import List -from typing import Tuple - import numpy as np import paddle from paddlespeech.t2s.datasets.batch import batch_sequences -from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.modules.nets_utils import get_seg_pos from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask -from paddlespeech.t2s.modules.nets_utils import pad_list from paddlespeech.t2s.modules.nets_utils import phones_masking from paddlespeech.t2s.modules.nets_utils import phones_text_masking @@ -490,182 +483,3 @@ def vits_single_spk_batch_fn(examples): "speech": speech } return batch - - -# for ERNIE SAT -class MLMCollateFn: - """Functor class of common_collate_fn()""" - - def __init__( - self, - feats_extract, - mlm_prob: float=0.8, - mean_phn_span: int=8, - seg_emb: bool=False, - text_masking: bool=False, - attention_window: int=0, - not_sequence: Collection[str]=(), ): - self.mlm_prob = mlm_prob - self.mean_phn_span = mean_phn_span - self.feats_extract = feats_extract - self.not_sequence = set(not_sequence) - self.attention_window = attention_window - self.seg_emb = seg_emb - self.text_masking = text_masking - - def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] - ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - return mlm_collate_fn( - data, - feats_extract=self.feats_extract, - mlm_prob=self.mlm_prob, - mean_phn_span=self.mean_phn_span, - seg_emb=self.seg_emb, - text_masking=self.text_masking, - not_sequence=self.not_sequence) - - -def mlm_collate_fn( - data: Collection[Tuple[str, Dict[str, np.ndarray]]], - feats_extract=None, - mlm_prob: float=0.8, - mean_phn_span: int=8, - seg_emb: bool=False, - text_masking: bool=False, - pad_value: int=0, - not_sequence: Collection[str]=(), -) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - uttids = [u for u, _ in data] - data = [d for _, d in data] - - assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" - assert all(not k.endswith("_lens") - for k in data[0]), f"*_lens is reserved: {list(data[0])}" - - output = {} - for key in data[0]: - - array_list = [d[key] for d in data] - - # Assume the first axis is length: - # tensor_list: Batch x (Length, ...) - tensor_list = [paddle.to_tensor(a) for a in array_list] - # tensor: (Batch, Length, ...) 
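-        # pad_list right-pads every (Length_i, ...) tensor in tensor_list with
-        # pad_value so the batch stacks into one (Batch, max_len, ...) tensor.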
-        tensor = pad_list(tensor_list, pad_value)
-        output[key] = tensor
-
-        # lens: (Batch,)
-        if key not in not_sequence:
-            lens = paddle.to_tensor(
-                [d[key].shape[0] for d in data], dtype=paddle.int64)
-            output[key + "_lens"] = lens
-
-    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
-    feats = paddle.to_tensor(feats)
-    print("feats.shape:", feats.shape)
-    feats_lens = paddle.shape(feats)[0]
-    feats = paddle.unsqueeze(feats, 0)
-
-    text = output["text"]
-    text_lens = output["text_lens"]
-    align_start = output["align_start"]
-    align_start_lens = output["align_start_lens"]
-    align_end = output["align_end"]
-
-    max_tlen = max(text_lens)
-    max_slen = max(feats_lens)
-
-    speech_pad = feats[:, :max_slen]
-
-    text_pad = text
-    text_mask = make_non_pad_mask(
-        text_lens, text_pad, length_dim=1).unsqueeze(-2)
-    speech_mask = make_non_pad_mask(
-        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
-
-    span_bdy = None
-    if 'span_bdy' in output.keys():
-        span_bdy = output['span_bdy']
-
-    # dual masking: for mixed Chinese-English input, mask speech and text at the same time
-    # ERNIE-SAT masks both when doing cross-lingual tasks
-    if text_masking:
-        masked_pos, text_masked_pos = phones_text_masking(
-            xs_pad=speech_pad,
-            src_mask=speech_mask,
-            text_pad=text_pad,
-            text_mask=text_mask,
-            align_start=align_start,
-            align_end=align_end,
-            align_start_lens=align_start_lens,
-            mlm_prob=mlm_prob,
-            mean_phn_span=mean_phn_span,
-            span_bdy=span_bdy)
-    # for pure-Chinese and pure-English training -> A3T does not mask phonemes, only speech
-    # the main difference between A3T and ERNIE-SAT lies in how masking is done
-    else:
-        masked_pos = phones_masking(
-            xs_pad=speech_pad,
-            src_mask=speech_mask,
-            align_start=align_start,
-            align_end=align_end,
-            align_start_lens=align_start_lens,
-            mlm_prob=mlm_prob,
-            mean_phn_span=mean_phn_span,
-            span_bdy=span_bdy)
-        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
-
-    output_dict = {}
-
-    speech_seg_pos, text_seg_pos = get_seg_pos(
-        speech_pad=speech_pad,
-        text_pad=text_pad,
-        align_start=align_start,
-        align_end=align_end,
-        align_start_lens=align_start_lens,
-        seg_emb=seg_emb)
-    output_dict['speech'] = speech_pad
-    output_dict['text'] = text_pad
-    output_dict['masked_pos'] = masked_pos
-    output_dict['text_masked_pos'] = text_masked_pos
-    output_dict['speech_mask'] = speech_mask
-    output_dict['text_mask'] = text_mask
-    output_dict['speech_seg_pos'] = speech_seg_pos
-    output_dict['text_seg_pos'] = text_seg_pos
-    output = (uttids, output_dict)
-    return output
-
-
-def build_mlm_collate_fn(
-        sr: int=24000,
-        n_fft: int=2048,
-        hop_length: int=300,
-        win_length: int=None,
-        n_mels: int=80,
-        fmin: int=80,
-        fmax: int=7600,
-        mlm_prob: float=0.8,
-        mean_phn_span: int=8,
-        seg_emb: bool=False,
-        epoch: int=-1, ):
-    feats_extract_class = LogMelFBank
-
-    feats_extract = feats_extract_class(
-        sr=sr,
-        n_fft=n_fft,
-        hop_length=hop_length,
-        win_length=win_length,
-        n_mels=n_mels,
-        fmin=fmin,
-        fmax=fmax)
-
-    if epoch == -1:
-        mlm_prob_factor = 1
-    else:
-        mlm_prob_factor = 0.8
-
-    return MLMCollateFn(
-        feats_extract=feats_extract,
-        mlm_prob=mlm_prob * mlm_prob_factor,
-        mean_phn_span=mean_phn_span,
-        seg_emb=seg_emb)
diff --git a/paddlespeech/t2s/models/ernie_sat/__init__.py b/paddlespeech/t2s/models/ernie_sat/__init__.py
index 7e795370e..87e7afe85 100644
--- a/paddlespeech/t2s/models/ernie_sat/__init__.py
+++ b/paddlespeech/t2s/models/ernie_sat/__init__.py
@@ -13,4 +13,3 @@
 # limitations under the License.
 from .ernie_sat import *
 from .ernie_sat_updater import *
-from .mlm import *
diff --git a/paddlespeech/t2s/models/ernie_sat/mlm.py b/paddlespeech/t2s/models/ernie_sat/mlm.py
deleted file mode 100644
index 647fdd9b4..000000000
--- a/paddlespeech/t2s/models/ernie_sat/mlm.py
+++ /dev/null
@@ -1,579 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-from typing import Dict
-from typing import List
-from typing import Optional
-
-import paddle
-import yaml
-from paddle import nn
-from yacs.config import CfgNode
-
-from paddlespeech.t2s.modules.activation import get_activation
-from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
-from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
-from paddlespeech.t2s.modules.layer_norm import LayerNorm
-from paddlespeech.t2s.modules.masked_fill import masked_fill
-from paddlespeech.t2s.modules.nets_utils import initialize
-from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
-from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
-from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
-from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
-from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
-from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
-from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
-from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
-from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
-from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
-from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from paddlespeech.t2s.modules.transformer.repeat import repeat
-from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
-
-
-# MLM -> Masked Language Model
-class mySequential(nn.Sequential):
-    def forward(self, *inputs):
-        for module in self._sub_layers.values():
-            if type(inputs) == tuple:
-                inputs = module(*inputs)
-            else:
-                inputs = module(inputs)
-        return inputs
-
-
-class MaskInputLayer(nn.Layer):
-    def __init__(self, out_features: int) -> None:
-        super().__init__()
-        self.mask_feature = paddle.create_parameter(
-            shape=(1, 1, out_features),
-            dtype=paddle.float32,
-            default_initializer=paddle.nn.initializer.Assign(
-                paddle.normal(shape=(1, 1, out_features))))
-
-    def forward(self, input: paddle.Tensor,
-                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
-        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
-        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
-            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
-        return masked_input
-
-
-class MLMEncoder(nn.Layer):
-    """Conformer encoder module.
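-
-    Used by ERNIE-SAT as a masked-LM backbone: it embeds speech (with masked
-    positions replaced by a learned mask feature) and text, then encodes the
-    concatenated streams with Conformer blocks.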
-
-    Args:
-        idim (int): Input dimension.
-        attention_dim (int): Dimension of attention.
-        attention_heads (int): The number of heads of multi head attention.
-        linear_units (int): The number of units of position-wise feed forward.
-        num_blocks (int): The number of decoder blocks.
-        dropout_rate (float): Dropout rate.
-        positional_dropout_rate (float): Dropout rate after adding positional encoding.
-        attention_dropout_rate (float): Dropout rate in attention.
-        input_layer (Union[str, paddle.nn.Layer]): Input layer type.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
-        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
-        macaron_style (bool): Whether to use macaron style for positionwise layer.
-        pos_enc_layer_type (str): Encoder positional encoding layer type.
-        selfattention_layer_type (str): Encoder attention layer type.
-        activation_type (str): Encoder activation function type.
-        use_cnn_module (bool): Whether to use convolution module.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-        cnn_module_kernel (int): Kernel size of convolution module.
-        padding_idx (int): Padding idx for input_layer=embed.
-        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
-
-    """
-
-    def __init__(self,
-                 idim: int,
-                 vocab_size: int=0,
-                 pre_speech_layer: int=0,
-                 attention_dim: int=256,
-                 attention_heads: int=4,
-                 linear_units: int=2048,
-                 num_blocks: int=6,
-                 dropout_rate: float=0.1,
-                 positional_dropout_rate: float=0.1,
-                 attention_dropout_rate: float=0.0,
-                 input_layer: str="conv2d",
-                 normalize_before: bool=True,
-                 concat_after: bool=False,
-                 positionwise_layer_type: str="linear",
-                 positionwise_conv_kernel_size: int=1,
-                 macaron_style: bool=False,
-                 pos_enc_layer_type: str="abs_pos",
-                 selfattention_layer_type: str="selfattn",
-                 activation_type: str="swish",
-                 use_cnn_module: bool=False,
-                 zero_triu: bool=False,
-                 cnn_module_kernel: int=31,
-                 padding_idx: int=-1,
-                 stochastic_depth_rate: float=0.0,
-                 text_masking: bool=False):
-        """Construct an Encoder object."""
-        super().__init__()
-        self._output_size = attention_dim
-        self.text_masking = text_masking
-        if self.text_masking:
-            self.text_masking_layer = MaskInputLayer(attention_dim)
-        activation = get_activation(activation_type)
-        if pos_enc_layer_type == "abs_pos":
-            pos_enc_class = PositionalEncoding
-        elif pos_enc_layer_type == "scaled_abs_pos":
-            pos_enc_class = ScaledPositionalEncoding
-        elif pos_enc_layer_type == "rel_pos":
-            assert selfattention_layer_type == "rel_selfattn"
-            pos_enc_class = RelPositionalEncoding
-        elif pos_enc_layer_type == "legacy_rel_pos":
-            pos_enc_class = LegacyRelPositionalEncoding
-            assert selfattention_layer_type == "legacy_rel_selfattn"
-        else:
-            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-
-        self.conv_subsampling_factor = 1
-        if input_layer == "linear":
-            self.embed = nn.Sequential(
-                nn.Linear(idim, attention_dim),
-                nn.LayerNorm(attention_dim),
-                nn.Dropout(dropout_rate),
-                nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate), )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(
-                idim,
-                attention_dim,
-                dropout_rate,
-                
pos_enc_class(attention_dim, positional_dropout_rate), ) - self.conv_subsampling_factor = 4 - elif input_layer == "embed": - self.embed = nn.Sequential( - nn.Embedding(idim, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), ) - elif input_layer == "mlm": - self.segment_emb = None - self.speech_embed = mySequential( - MaskInputLayer(idim), - nn.Linear(idim, attention_dim), - nn.LayerNorm(attention_dim), - nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate)) - self.text_embed = nn.Sequential( - nn.Embedding( - vocab_size, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), ) - elif input_layer == "sega_mlm": - self.segment_emb = nn.Embedding( - 500, attention_dim, padding_idx=padding_idx) - self.speech_embed = mySequential( - MaskInputLayer(idim), - nn.Linear(idim, attention_dim), - nn.LayerNorm(attention_dim), - nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate)) - self.text_embed = nn.Sequential( - nn.Embedding( - vocab_size, attention_dim, padding_idx=padding_idx), - pos_enc_class(attention_dim, positional_dropout_rate), ) - elif isinstance(input_layer, nn.Layer): - self.embed = nn.Sequential( - input_layer, - pos_enc_class(attention_dim, positional_dropout_rate), ) - elif input_layer is None: - self.embed = nn.Sequential( - pos_enc_class(attention_dim, positional_dropout_rate)) - else: - raise ValueError("unknown input_layer: " + input_layer) - self.normalize_before = normalize_before - - # self-attention module definition - if selfattention_layer_type == "selfattn": - encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = (attention_heads, attention_dim, - attention_dropout_rate, ) - elif selfattention_layer_type == "legacy_rel_selfattn": - assert pos_enc_layer_type == "legacy_rel_pos" - encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention - encoder_selfattn_layer_args = (attention_heads, attention_dim, - attention_dropout_rate, ) - elif selfattention_layer_type == "rel_selfattn": - assert pos_enc_layer_type == "rel_pos" - encoder_selfattn_layer = RelPositionMultiHeadedAttention - encoder_selfattn_layer_args = (attention_heads, attention_dim, - attention_dropout_rate, zero_triu, ) - else: - raise ValueError("unknown encoder_attn_layer: " + - selfattention_layer_type) - - # feed-forward module definition - if positionwise_layer_type == "linear": - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = (attention_dim, linear_units, - dropout_rate, activation, ) - elif positionwise_layer_type == "conv1d": - positionwise_layer = MultiLayeredConv1d - positionwise_layer_args = (attention_dim, linear_units, - positionwise_conv_kernel_size, - dropout_rate, ) - elif positionwise_layer_type == "conv1d-linear": - positionwise_layer = Conv1dLinear - positionwise_layer_args = (attention_dim, linear_units, - positionwise_conv_kernel_size, - dropout_rate, ) - else: - raise NotImplementedError("Support only linear or conv1d.") - - # convolution module definition - convolution_layer = ConvolutionModule - convolution_layer_args = (attention_dim, cnn_module_kernel, activation) - - self.encoders = repeat( - num_blocks, - lambda lnum: EncoderLayer( - attention_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - positionwise_layer(*positionwise_layer_args) if macaron_style else None, - convolution_layer(*convolution_layer_args) if use_cnn_module else None, - dropout_rate, - 
normalize_before, - concat_after, - stochastic_depth_rate * float(1 + lnum) / num_blocks, ), ) - self.pre_speech_layer = pre_speech_layer - self.pre_speech_encoders = repeat( - self.pre_speech_layer, - lambda lnum: EncoderLayer( - attention_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - positionwise_layer(*positionwise_layer_args) if macaron_style else None, - convolution_layer(*convolution_layer_args) if use_cnn_module else None, - dropout_rate, - normalize_before, - concat_after, - stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ), - ) - if self.normalize_before: - self.after_norm = LayerNorm(attention_dim) - - def forward(self, - speech: paddle.Tensor, - text: paddle.Tensor, - masked_pos: paddle.Tensor, - speech_mask: paddle.Tensor=None, - text_mask: paddle.Tensor=None, - speech_seg_pos: paddle.Tensor=None, - text_seg_pos: paddle.Tensor=None): - """Encode input sequence. - - """ - if masked_pos is not None: - speech = self.speech_embed(speech, masked_pos) - else: - speech = self.speech_embed(speech) - if text is not None: - text = self.text_embed(text) - if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb: - speech_seg_emb = self.segment_emb(speech_seg_pos) - text_seg_emb = self.segment_emb(text_seg_pos) - text = (text[0] + text_seg_emb, text[1]) - speech = (speech[0] + speech_seg_emb, speech[1]) - if self.pre_speech_encoders: - speech, _ = self.pre_speech_encoders(speech, speech_mask) - - if text is not None: - xs = paddle.concat([speech[0], text[0]], axis=1) - xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1) - masks = paddle.concat([speech_mask, text_mask], axis=-1) - else: - xs = speech[0] - xs_pos_emb = speech[1] - masks = speech_mask - - xs, masks = self.encoders((xs, xs_pos_emb), masks) - - if isinstance(xs, tuple): - xs = xs[0] - if self.normalize_before: - xs = self.after_norm(xs) - - return xs, masks - - -class MLMDecoder(MLMEncoder): - def forward(self, xs: paddle.Tensor, masks: paddle.Tensor): - """Encode input sequence. - - Args: - xs (paddle.Tensor): Input tensor (#batch, time, idim). - masks (paddle.Tensor): Mask tensor (#batch, time). - - Returns: - paddle.Tensor: Output tensor (#batch, time, attention_dim). - paddle.Tensor: Mask tensor (#batch, time). 
- - """ - xs = self.embed(xs) - xs, masks = self.encoders(xs, masks) - - if isinstance(xs, tuple): - xs = xs[0] - if self.normalize_before: - xs = self.after_norm(xs) - - return xs, masks - - -# encoder and decoder is nn.Layer, not str -class MLM(nn.Layer): - def __init__(self, - odim: int, - encoder: nn.Layer, - decoder: Optional[nn.Layer], - postnet_layers: int=0, - postnet_chans: int=0, - postnet_filts: int=0, - text_masking: bool=False): - - super().__init__() - self.odim = odim - self.encoder = encoder - self.decoder = decoder - self.vocab_size = encoder.text_embed[0]._num_embeddings - - if self.decoder is None or not (hasattr(self.decoder, - 'output_layer') and - self.decoder.output_layer is not None): - self.sfc = nn.Linear(self.encoder._output_size, odim) - else: - self.sfc = None - if text_masking: - self.text_sfc = nn.Linear( - self.encoder.text_embed[0]._embedding_dim, - self.vocab_size, - weight_attr=self.encoder.text_embed[0]._weight_attr) - else: - self.text_sfc = None - - self.postnet = (None if postnet_layers == 0 else Postnet( - idim=self.encoder._output_size, - odim=odim, - n_layers=postnet_layers, - n_chans=postnet_chans, - n_filts=postnet_filts, - use_batch_norm=True, - dropout_rate=0.5, )) - - def inference( - self, - speech: paddle.Tensor, - text: paddle.Tensor, - masked_pos: paddle.Tensor, - speech_mask: paddle.Tensor, - text_mask: paddle.Tensor, - speech_seg_pos: paddle.Tensor, - text_seg_pos: paddle.Tensor, - span_bdy: List[int], - use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: - ''' - Args: - speech (paddle.Tensor): input speech (1, Tmax, D). - text (paddle.Tensor): input text (1, Tmax2). - masked_pos (paddle.Tensor): masked position of input speech (1, Tmax) - speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax). - text_mask (paddle.Tensor): mask of text (1, 1, Tmax2). - speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax). - text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2). 
-            span_bdy (List[int]): masked mel boundary of input speech (2,)
-            use_teacher_forcing (bool): whether to use teacher forcing
-        Returns:
-            List[Tensor]:
-                e.g.:
-                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
-        '''
-
-        z_cache = None
-        if use_teacher_forcing:
-            before_outs, zs, *_ = self.forward(
-                speech=speech,
-                text=text,
-                masked_pos=masked_pos,
-                speech_mask=speech_mask,
-                text_mask=text_mask,
-                speech_seg_pos=speech_seg_pos,
-                text_seg_pos=text_seg_pos)
-            if zs is None:
-                zs = before_outs
-
-            speech = speech.squeeze(0)
-            outs = [speech[:span_bdy[0]]]
-            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
-            outs += [speech[span_bdy[1]:]]
-            return outs
-        return None
-
-
-class MLMEncAsDecoder(MLM):
-    def forward(self,
-                speech: paddle.Tensor,
-                text: paddle.Tensor,
-                masked_pos: paddle.Tensor,
-                speech_mask: paddle.Tensor,
-                text_mask: paddle.Tensor,
-                speech_seg_pos: paddle.Tensor,
-                text_seg_pos: paddle.Tensor):
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, h_masks = self.encoder(
-            speech=speech,
-            text=text,
-            masked_pos=masked_pos,
-            speech_mask=speech_mask,
-            text_mask=text_mask,
-            speech_seg_pos=speech_seg_pos,
-            text_seg_pos=text_seg_pos)
-        if self.decoder is not None:
-            zs, _ = self.decoder(encoder_out, h_masks)
-        else:
-            zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
-        if self.sfc is not None:
-            before_outs = paddle.reshape(
-                self.sfc(speech_hidden_states),
-                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
-        else:
-            before_outs = speech_hidden_states
-        if self.postnet is not None:
-            after_outs = before_outs + paddle.transpose(
-                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
-                [0, 2, 1])
-        else:
-            after_outs = None
-        return before_outs, after_outs, None
-
-
-class MLMDualMaksing(MLM):
-    def forward(self,
-                speech: paddle.Tensor,
-                text: paddle.Tensor,
-                masked_pos: paddle.Tensor,
-                speech_mask: paddle.Tensor,
-                text_mask: paddle.Tensor,
-                speech_seg_pos: paddle.Tensor,
-                text_seg_pos: paddle.Tensor):
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, h_masks = self.encoder(
-            speech=speech,
-            text=text,
-            masked_pos=masked_pos,
-            speech_mask=speech_mask,
-            text_mask=text_mask,
-            speech_seg_pos=speech_seg_pos,
-            text_seg_pos=text_seg_pos)
-        if self.decoder is not None:
-            zs, _ = self.decoder(encoder_out, h_masks)
-        else:
-            zs = encoder_out
-        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
-        text_outs = None
-        if self.text_sfc:
-            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
-            text_outs = paddle.reshape(
-                self.text_sfc(text_hidden_states),
-                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
-        if self.sfc is not None:
-            before_outs = paddle.reshape(
-                self.sfc(speech_hidden_states),
-                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
-        else:
-            before_outs = speech_hidden_states
-        if self.postnet is not None:
-            after_outs = before_outs + paddle.transpose(
-                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
-                [0, 2, 1])
-        else:
-            after_outs = None
-        return before_outs, after_outs, text_outs
-
-
-def build_model_from_file(config_file, model_file):
-
-    state_dict = paddle.load(model_file)
-    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
-        else MLMEncAsDecoder
-
-    # build the model
-    with open(config_file) as f:
-        conf = CfgNode(yaml.safe_load(f))
-    model = build_model(conf, model_class)
-    model.set_state_dict(state_dict)
-    return model, conf
-
-
-# select encoder and decoder here
-def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
-    if isinstance(args.token_list, str):
-        with open(args.token_list, encoding="utf-8") as f:
-            token_list = [line.rstrip() for line in f]
-
-        # Overwriting token_list to keep it as "portable".
-        args.token_list = list(token_list)
-    elif isinstance(args.token_list, (tuple, list)):
-        token_list = list(args.token_list)
-    else:
-        raise RuntimeError("token_list must be str or list")
-
-    vocab_size = len(token_list)
-    odim = 80
-
-    # Encoder
-    encoder_class = MLMEncoder
-
-    if 'text_masking' in args.model_conf.keys() and args.model_conf[
-            'text_masking']:
-        args.encoder_conf['text_masking'] = True
-    else:
-        args.encoder_conf['text_masking'] = False
-
-    encoder = encoder_class(
-        args.input_size, vocab_size=vocab_size, **args.encoder_conf)
-
-    # Decoder
-    if args.decoder != 'no_decoder':
-        decoder_class = MLMDecoder
-        decoder = decoder_class(
-            idim=0,
-            input_layer=None,
-            **args.decoder_conf, )
-    else:
-        decoder = None
-
-    # Build model
-    model = model_class(
-        odim=odim,
-        encoder=encoder,
-        decoder=decoder,
-        **args.model_conf, )
-
-    # Initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/tools/get_contributors.ipynb b/tools/get_contributors.ipynb
new file mode 100644
index 000000000..a8ad99efa
--- /dev/null
+++ b/tools/get_contributors.ipynb
@@ -0,0 +1,146 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "automotive-trailer",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from selenium import webdriver\n",
+    "chromeOptions = webdriver.ChromeOptions()\n",
+    "driver = webdriver.Chrome('./chromedriver', chrome_options=chromeOptions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "physical-croatia",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "driver.get(\"https://github.com/PaddlePaddle/PaddleSpeech/graphs/contributors\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "seventh-latitude",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Reference: text content of one contributor entry on the scraped page\n",
+    "# (avatar link, rank, name, and commit/addition/deletion counts):\n",
+    "#   zh794390558  #1\n",
+    "#   655 commits   3,671,956 ++   1,966,288 --"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "modified-argument",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from selenium.webdriver.common.by import By"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "demonstrated-aging",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "elements = driver.find_elements(By.CLASS_NAME, 'lh-condensed')\n",
+    "for element in elements:\n",
+    "    zhuye = element.find_elements(By.CLASS_NAME, 'd-inline-block')[0].get_attribute(\"href\")\n",
+    "    img = element.find_elements(By.CLASS_NAME, 'avatar')[0].get_attribute(\"src\")\n",
+    "    # markup for the README contributors table; width/height values are assumed\n",
+    "    mkdown = f\"\"\"<a href=\"{zhuye}\"><img src=\"{img}\" width=75 height=75></a>\"\"\"\n",
+    "    print(mkdown)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "general-torture",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "element.find_elements(By.CLASS_NAME, 'd-inline-block')[0].get_attribute(\"href\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "downtown-institute",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "element.find_elements(By.CLASS_NAME, 'avatar')[0].get_attribute(\"src\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "worthy-planet",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(elements)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}