diff --git a/examples/aishell3/ernie_sat/README.md b/examples/aishell3/ernie_sat/README.md
new file mode 100644
index 000000000..8086d007c
--- /dev/null
+++ b/examples/aishell3/ernie_sat/README.md
@@ -0,0 +1 @@
+# ERNIE SAT with AISHELL3 dataset
diff --git a/examples/aishell3_vctk/README.md b/examples/aishell3_vctk/README.md
new file mode 100644
index 000000000..330b25934
--- /dev/null
+++ b/examples/aishell3_vctk/README.md
@@ -0,0 +1 @@
+# Mixed Chinese and English TTS with AISHELL3 and VCTK datasets
diff --git a/examples/aishell3_vctk/ernie_sat/README.md b/examples/aishell3_vctk/ernie_sat/README.md
new file mode 100644
index 000000000..1c6bbe230
--- /dev/null
+++ b/examples/aishell3_vctk/ernie_sat/README.md
@@ -0,0 +1 @@
+# ERNIE SAT with AISHELL3 and VCTK datasets
diff --git a/examples/ernie_sat/.meta/framework.png b/examples/ernie_sat/.meta/framework.png
new file mode 100644
index 000000000..c68f62467
Binary files /dev/null and b/examples/ernie_sat/.meta/framework.png differ
diff --git a/examples/ernie_sat/README.md b/examples/ernie_sat/README.md
new file mode 100644
index 000000000..d3bd13372
--- /dev/null
+++ b/examples/ernie_sat/README.md
@@ -0,0 +1,137 @@
+ERNIE-SAT is a cross-lingual speech-language multimodal pretrained model that handles Chinese and English jointly. It achieves leading results on speech editing, personalized speech synthesis, and cross-lingual speech synthesis, and can be applied to scenarios such as speech editing, personalized synthesis, voice cloning, and simultaneous interpretation. This project is released for research use.
+
+## Model Framework
+ERNIE-SAT introduces two innovations:
+- During pretraining, the phonemes corresponding to the Chinese and English text are used as input, enabling cross-lingual, personalized soft phoneme mapping.
+- Joint masked learning over language and speech aligns the two modalities.
+
+![ERNIE-SAT framework](.meta/framework.png)
+
+## Usage
+
+### 1. Install PaddlePaddle and environment dependencies
+
+- The code of this project is based on Paddle (version >= 2.0).
+- The project can optionally load a torch-based vocoder:
+  - torch version >= 1.8
+
+- Install HTK: after registering at the [official site](https://htk.eng.cam.ac.uk/), you can download a recent version of HTK (e.g. 3.4.1). [Earlier versions](https://htk.eng.cam.ac.uk/ftp/software/) are also available.
+
+  - 1. Register an account and download HTK.
+  - 2. Extract the HTK archive and **put it into the tools folder at the project root, in a directory named htk**.
+  - 3. **Note**: if you downloaded version 3.4.1 or later, open HTKLib/HRec.c and **edit lines 1626 and 1650**, changing **dur<=0 to dur<0 on both lines**, as shown below:
+  ```bash
+  Taking htk 3.4.1 as an example:
+  (1) line 1626: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
+      change to: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0");
+
+  (2) line 1650: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
+      change to: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
+  ```
+  - 4. **Compile**: see the README inside the extracted HTK archive for details (HTK cannot run unless it has been compiled).
+
+
+
+- Install ParallelWaveGAN: see the [official repository](https://github.com/kan-bayashi/ParallelWaveGAN). Following its installation instructions, git clone the ParallelWaveGAN project **in the root directory of this project** and install its dependencies.
+
+
+- Install the other dependencies: **sox, libsndfile**, etc.
+
+### 2. Pretrained Models
+The pretrained ERNIE-SAT models are listed below:
+- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
+- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
+- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
+
+
+Create a pretrained_model folder, then download the pretrained ERNIE-SAT models above and extract them:
+```bash
+mkdir pretrained_model
+cd pretrained_model
+tar -zxvf model-ernie-sat-base-en.tar.gz
+tar -zxvf model-ernie-sat-base-zh.tar.gz
+tar -zxvf model-ernie-sat-base-en_zh.tar.gz
+```
+
+### 3. Download
+
+1. This project uses Parallel WaveGAN as the vocoder:
+    - [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
+
+    Create a download folder, then download the pretrained vocoder model above and extract it:
+
+    ```bash
+    mkdir download
+    cd download
+    unzip pwg_aishell3_ckpt_0.5.zip
+    ```
+
+2. This project uses [FastSpeech2](https://arxiv.org/abs/2006.04558) as the phoneme duration predictor:
+    - [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) used for Chinese
+    - [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) used for English
+
+    Download the pretrained fastspeech2 models above and extract them:
+
+    ```bash
+    cd download
+    unzip fastspeech2_conformer_baker_ckpt_0.5.zip
+    unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
+    ```
+
+3. This project uses HTK to obtain the alignment between the input audio and text:
+
+    - [aligner.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/aligner.zip)
+
+    Download the file above into the tools folder and extract it:
+    ```bash
+    cd tools
+    unzip aligner.zip
+    ```
+
+
+### 4. Inference
+
+This project currently releases the inference code for speech editing, personalized speech synthesis, and cross-lingual speech synthesis; the rest will be open-sourced gradually.
+Note: for synthesis in English scenarios the default vocoder is vctk_parallel_wavegan.v1.long, which can be found at [this link](https://github.com/kan-bayashi/ParallelWaveGAN); if the use_pt_vocoder parameter is set to False, the Paddle vocoder is used for English instead.
+
+We provide sample audio files together with their corresponding text and phoneme files:
+- prompt_wav: the provided audio files
+- prompt/dev: the text and phoneme files corresponding to the audio files above
+
+
+```text
+prompt_wav
+├── p299_096.wav   # sample audio file 1
+├── p243_313.wav   # sample audio file 2
+└── ...
+```
+
+```text
+prompt/dev
+├── text        # transcripts of the sample audio
+├── wav.scp     # paths of the sample audio
+├── mfa_text    # phonemes of the sample audio
+├── mfa_start   # start time of each phoneme
+└── mfa_end     # end time of each phoneme
+```
+The main inference arguments are:
+1. `--am` the acoustic model, in the format {model_name}_{dataset}
+2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments of the acoustic model, corresponding to the 4 files in the fastspeech2 pretrained model.
+3. `--voc` the vocoder, in the format {model_name}_{dataset}
+4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments of the vocoder, corresponding to the 3 files in the parallel wavegan pretrained model.
+5. `--ngpu` the number of GPUs to use; if ngpu == 0, the CPU is used.
+6. `--model_name` the model name
+7. `--uid` the id of the specific prompt audio
+8. `--new_str` the input text (this release currently expects the preset texts)
+9. `--prefix` the directory of the text and phoneme files for the prompt audio
+10. `--source_lang` the source language
+11. `--target_lang` the target language
+12. `--output_name` the name of the synthesized audio file
+13. `--task_name` the task name: speech editing, personalized speech synthesis, or cross-lingual speech synthesis
+
+Run the following scripts to reproduce the experiments:
+```shell
+./run_sedit_en.sh          # speech editing task (English)
+./run_gen_en.sh            # personalized speech synthesis task (English)
+./run_clone_en_to_zh.sh    # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
+```
diff --git a/examples/ernie_sat/local/align.py b/examples/ernie_sat/local/align.py
new file mode 100755
index 000000000..025877ddf
--- /dev/null
+++ b/examples/ernie_sat/local/align.py
@@ -0,0 +1,441 @@
+""" Usage:
+      align.py wavfile trsfile outwordfile outphonefile
+"""
+import os
+import sys
+
+PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
+MODEL_DIR_EN = 'tools/aligner/english'
+MODEL_DIR_ZH = 'tools/aligner/mandarin'
+HVITE = 'tools/htk/HTKTools/HVite'
+HCOPY = 'tools/htk/HTKTools/HCopy'
+
+
+def get_unk_phns(word_str: str):
+    tmpbase = '/tmp/tp.'
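+    # write the word to a temp file, run the english2phoneme letter-to-sound
+    # tool on it, then map the tool's output symbols to ARPABET-style phones below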
+ f = open(tmpbase + 'temp.words', 'w') + f.write(word_str) + f.close() + os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + + 'temp.phons') + f = open(tmpbase + 'temp.phons', 'r') + lines2 = f.readline().strip().split() + f.close() + phns = [] + for phn in lines2: + phons = phn.replace('\n', '').replace(' ', '') + seq = [] + j = 0 + while (j < len(phons)): + if (phons[j] > 'Z'): + if (phons[j] == 'j'): + seq.append('JH') + elif (phons[j] == 'h'): + seq.append('HH') + else: + seq.append(phons[j].upper()) + j += 1 + else: + p = phons[j:j + 2] + if (p == 'WH'): + seq.append('W') + elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): + seq.append(p) + elif (p == 'AX'): + seq.append('AH0') + else: + seq.append(p + '1') + j += 2 + phns.extend(seq) + return phns + + +def words2phns(line: str): + ''' + Args: + line (str): input text. + eg: for that reason cover is impossible to be given. + Returns: + List[str]: phones of input text. + eg: + ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0', + 'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1', + 'G', 'IH1', 'V', 'AH0', 'N'] + + Dict(str, str): key - idx_word + value - phones + eg: + {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], + '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'], + '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']} + ''' + dictfile = MODEL_DIR_EN + '/dict' + line = line.strip() + words = [] + for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + ds = set([]) + word2phns_dict = {} + with open(dictfile, 'r') as fid: + for line in fid: + word = line.split()[0] + ds.add(word) + if word not in word2phns_dict.keys(): + word2phns_dict[word] = " ".join(line.split()[1:]) + + phns = [] + wrd2phns = {} + for index, wrd in enumerate(words): + if wrd == '[MASK]': + wrd2phns[str(index) + "_" + wrd] = [wrd] + phns.append(wrd) + elif (wrd.upper() not in ds): + wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd) + phns.extend(get_unk_phns(wrd)) + else: + wrd2phns[str(index) + + "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split() + phns.extend(word2phns_dict[wrd.upper()].split()) + return phns, wrd2phns + + +def words2phns_zh(line: str): + dictfile = MODEL_DIR_ZH + '/dict' + line = line.strip() + words = [] + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + + ds = set([]) + word2phns_dict = {} + with open(dictfile, 'r') as fid: + for line in fid: + word = line.split()[0] + ds.add(word) + if word not in word2phns_dict.keys(): + word2phns_dict[word] = " ".join(line.split()[1:]) + + phns = [] + wrd2phns = {} + for index, wrd in enumerate(words): + if wrd == '[MASK]': + wrd2phns[str(index) + "_" + wrd] = [wrd] + phns.append(wrd) + elif (wrd.upper() not in ds): + print("出现非法词错误,请输入正确的文本...") + else: + wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split() + phns.extend(word2phns_dict[wrd].split()) + + return phns, wrd2phns + + +def 
prep_txt_zh(line: str, tmpbase: str, dictfile: str): + + words = [] + line = line.strip() + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + + ds = set([]) + with open(dictfile, 'r') as fid: + for line in fid: + ds.add(line.split()[0]) + + unk_words = set([]) + with open(tmpbase + '.txt', 'w') as fwid: + for wrd in words: + if (wrd not in ds): + unk_words.add(wrd) + fwid.write(wrd + ' ') + fwid.write('\n') + return unk_words + + +def prep_txt_en(line: str, tmpbase, dictfile): + + words = [] + + line = line.strip() + for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + + ds = set([]) + with open(dictfile, 'r') as fid: + for line in fid: + ds.add(line.split()[0]) + + unk_words = set([]) + with open(tmpbase + '.txt', 'w') as fwid: + for wrd in words: + if (wrd.upper() not in ds): + unk_words.add(wrd.upper()) + fwid.write(wrd + ' ') + fwid.write('\n') + + #generate pronounciations for unknows words using 'letter to sound' + with open(tmpbase + '_unk.words', 'w') as fwid: + for unk in unk_words: + fwid.write(unk + '\n') + try: + os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + + '_unk.phons') + except Exception: + print('english2phoneme error!') + sys.exit(1) + + #add unknown words to the standard dictionary, generate a tmp dictionary for alignment + fw = open(tmpbase + '.dict', 'w') + with open(dictfile, 'r') as fid: + for line in fid: + fw.write(line) + f = open(tmpbase + '_unk.words', 'r') + lines1 = f.readlines() + f.close() + f = open(tmpbase + '_unk.phons', 'r') + lines2 = f.readlines() + f.close() + for i in range(len(lines1)): + wrd = lines1[i].replace('\n', '') + phons = lines2[i].replace('\n', '').replace(' ', '') + seq = [] + j = 0 + while (j < len(phons)): + if (phons[j] > 'Z'): + if (phons[j] == 'j'): + seq.append('JH') + elif (phons[j] == 'h'): + seq.append('HH') + else: + seq.append(phons[j].upper()) + j += 1 + else: + p = phons[j:j + 2] + if (p == 'WH'): + seq.append('W') + elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): + seq.append(p) + elif (p == 'AX'): + seq.append('AH0') + else: + seq.append(p + '1') + j += 2 + + fw.write(wrd + ' ') + for s in seq: + fw.write(' ' + s) + fw.write('\n') + fw.close() + + +def prep_mlf(txt: str, tmpbase: str): + + with open(tmpbase + '.mlf', 'w') as fwid: + fwid.write('#!MLF!#\n') + fwid.write('"' + tmpbase + '.lab"\n') + fwid.write('sp\n') + wrds = txt.split() + for wrd in wrds: + fwid.write(wrd.upper() + '\n') + fwid.write('sp\n') + fwid.write('.\n') + + +def _get_user(): + return os.path.expanduser('~').split("/")[-1] + + +def alignment(wav_path: str, text: str): + ''' + intervals: List[phn, start, end] + ''' + tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) + + #prepare wav and trs files + try: + os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') + except Exception: + print('sox error!') + return None + + #prepare clean_transcript file + try: + prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict') + except Exception: + print('prep_txt error!') + return None + + #prepare mlf file + try: + with open(tmpbase + '.txt', 'r') as fid: + txt = 
fid.readline() + prep_mlf(txt, tmpbase) + except Exception: + print('prep_mlf error!') + return None + + #prepare scp + try: + os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase + + '.wav' + ' ' + tmpbase + '.plp') + except Exception: + print('HCopy error!') + return None + + #run alignment + try: + os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + + '.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + + '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase + + '.plp 2>&1 > /dev/null') + except Exception: + print('HVite error!') + return None + + with open(tmpbase + '.txt', 'r') as fid: + words = fid.readline().strip().split() + words = txt.strip().split() + words.reverse() + + with open(tmpbase + '.aligned', 'r') as fid: + lines = fid.readlines() + i = 2 + intervals = [] + word2phns = {} + current_word = '' + index = 0 + while (i < len(lines)): + splited_line = lines[i].strip().split() + if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): + phn = splited_line[2] + pst = (int(splited_line[0]) / 1000 + 125) / 10000 + pen = (int(splited_line[1]) / 1000 + 125) / 10000 + intervals.append([phn, pst, pen]) + # splited_line[-1]!='sp' + if len(splited_line) == 5: + current_word = str(index) + '_' + splited_line[-1] + word2phns[current_word] = phn + index += 1 + elif len(splited_line) == 4: + word2phns[current_word] += ' ' + phn + i += 1 + return intervals, word2phns + + +def alignment_zh(wav_path: str, text: str): + tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) + + #prepare wav and trs files + try: + os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + + '.wav remix -') + + except Exception: + print('sox error!') + return None + + #prepare clean_transcript file + try: + unk_words = prep_txt_zh( + line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict') + if unk_words: + print('Error! 
Please add the following words to dictionary:') + for unk in unk_words: + print("非法words: ", unk) + except Exception: + print('prep_txt error!') + return None + + #prepare mlf file + try: + with open(tmpbase + '.txt', 'r') as fid: + txt = fid.readline() + prep_mlf(txt, tmpbase) + except Exception: + print('prep_mlf error!') + return None + + #prepare scp + try: + os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase + + '.wav' + ' ' + tmpbase + '.plp') + except Exception: + print('HCopy error!') + return None + + #run alignment + try: + os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + + '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH + + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase + + '.plp 2>&1 > /dev/null') + + except Exception: + print('HVite error!') + return None + + with open(tmpbase + '.txt', 'r') as fid: + words = fid.readline().strip().split() + words = txt.strip().split() + words.reverse() + + with open(tmpbase + '.aligned', 'r') as fid: + lines = fid.readlines() + + i = 2 + intervals = [] + word2phns = {} + current_word = '' + index = 0 + while (i < len(lines)): + splited_line = lines[i].strip().split() + if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): + phn = splited_line[2] + pst = (int(splited_line[0]) / 1000 + 125) / 10000 + pen = (int(splited_line[1]) / 1000 + 125) / 10000 + intervals.append([phn, pst, pen]) + # splited_line[-1]!='sp' + if len(splited_line) == 5: + current_word = str(index) + '_' + splited_line[-1] + word2phns[current_word] = phn + index += 1 + elif len(splited_line) == 4: + word2phns[current_word] += ' ' + phn + i += 1 + return intervals, word2phns diff --git a/examples/ernie_sat/local/inference.py b/examples/ernie_sat/local/inference.py new file mode 100644 index 000000000..196d9c6d0 --- /dev/null +++ b/examples/ernie_sat/local/inference.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python3 +import os +import random +from typing import Dict +from typing import List + +import librosa +import numpy as np +import paddle +import soundfile as sf +from align import alignment +from align import alignment_zh +from align import words2phns +from align import words2phns_zh +from paddle import nn +from sedit_arg_parser import parse_args +from utils import eval_durs +from utils import get_voc_out +from utils import is_chinese +from utils import load_num_sequence_text +from utils import read_2col_text + +from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn +from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file + +random.seed(0) +np.random.seed(0) + + +def get_wav(wav_path: str, + source_lang: str='english', + target_lang: str='english', + model_name: str="paddle_checkpoint_en", + old_str: str="", + new_str: str="", + non_autoreg: bool=True): + wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + use_teacher_forcing=non_autoreg) + + masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] + + alt_wav = get_voc_out(masked_feat) + + old_time_bdy = [hop_length * x for x in old_span_bdy] + + wav_replaced = np.concatenate( + [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) + + data_dict = {"origin": wav_org, "output": wav_replaced} + + return data_dict + + +def load_model(model_name: str="paddle_checkpoint_en"): + 
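+    # load the ERNIE-SAT MLM checkpoint and its training config from
+    # ./pretrained_model/{model_name}/ (see "2. Pretrained Models" in the README)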
config_path = './pretrained_model/{}/config.yaml'.format(model_name) + model_path = './pretrained_model/{}/model.pdparams'.format(model_name) + mlm_model, conf = build_model_from_file( + config_file=config_path, model_file=model_path) + return mlm_model, conf + + +def read_data(uid: str, prefix: os.PathLike): + # 获取 uid 对应的文本 + mfa_text = read_2col_text(prefix + '/text')[uid] + # 获取 uid 对应的音频路径 + mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid] + if not os.path.isabs(mfa_wav_path): + mfa_wav_path = prefix + mfa_wav_path + return mfa_text, mfa_wav_path + + +def get_align_data(uid: str, prefix: os.PathLike): + mfa_path = prefix + "mfa_" + mfa_text = read_2col_text(mfa_path + 'text')[uid] + mfa_start = load_num_sequence_text( + mfa_path + 'start', loader_type='text_float')[uid] + mfa_end = load_num_sequence_text( + mfa_path + 'end', loader_type='text_float')[uid] + mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid] + return mfa_text, mfa_start, mfa_end, mfa_wav_path + + +# 获取需要被 mask 的 mel 帧的范围 +def get_masked_mel_bdy(mfa_start: List[float], + mfa_end: List[float], + fs: int, + hop_length: int, + span_to_repl: List[List[int]]): + align_start = np.array(mfa_start) + align_end = np.array(mfa_end) + align_start = np.floor(fs * align_start / hop_length).astype('int') + align_end = np.floor(fs * align_end / hop_length).astype('int') + if span_to_repl[0] >= len(mfa_start): + span_bdy = [align_end[-1], align_end[-1]] + else: + span_bdy = [ + align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1] + ] + return span_bdy, align_start, align_end + + +def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): + dic = {} + keys_to_del = [] + exist_idx = [] + sp_count = 0 + add_sp_count = 0 + for key in word2phns.keys(): + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + exist_idx.append(int(idx)) + else: + keys_to_del.append(key) + + for key in keys_to_del: + del word2phns[key] + + cur_id = 0 + for key in tp_word2phns.keys(): + if cur_id in exist_idx: + dic[str(cur_id) + "_sp"] = 'sp' + cur_id += 1 + add_sp_count += 1 + idx, wrd = key.split('_') + dic[str(cur_id) + "_" + wrd] = tp_word2phns[key] + cur_id += 1 + + if add_sp_count + 1 == sp_count: + dic[str(cur_id) + "_sp"] = 'sp' + add_sp_count += 1 + + assert add_sp_count == sp_count, "sp are not added in dic" + return dic + + +def get_max_idx(dic): + return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] + + +def get_phns_and_spans(wav_path: str, + old_str: str="", + new_str: str="", + source_lang: str="english", + target_lang: str="english"): + is_append = (old_str == new_str[:len(old_str)]) + old_phns, mfa_start, mfa_end = [], [], [] + # source + if source_lang == "english": + intervals, word2phns = alignment(wav_path, old_str) + elif source_lang == "chinese": + intervals, word2phns = alignment_zh(wav_path, old_str) + _, tp_word2phns = words2phns_zh(old_str) + + for key, value in tp_word2phns.items(): + idx, wrd = key.split('_') + cur_val = " ".join(value) + tp_word2phns[key] = cur_val + + word2phns = recover_dict(word2phns, tp_word2phns) + else: + assert source_lang == "chinese" or source_lang == "english", \ + "source_lang is wrong..." 
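+    # each interval produced by the aligner is [phone, start_second, end_second]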
+ + for item in intervals: + old_phns.append(item[0]) + mfa_start.append(float(item[1])) + mfa_end.append(float(item[2])) + # target + if is_append and (source_lang != target_lang): + cross_lingual_clone = True + else: + cross_lingual_clone = False + + if cross_lingual_clone: + str_origin = new_str[:len(old_str)] + str_append = new_str[len(old_str):] + + if target_lang == "chinese": + phns_origin, origin_word2phns = words2phns(str_origin) + phns_append, append_word2phns_tmp = words2phns_zh(str_append) + + elif target_lang == "english": + # 原始句子 + phns_origin, origin_word2phns = words2phns_zh(str_origin) + # clone 句子 + phns_append, append_word2phns_tmp = words2phns(str_append) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "cloning is not support for this language, please check it." + + new_phns = phns_origin + phns_append + + append_word2phns = {} + length = len(origin_word2phns) + for key, value in append_word2phns_tmp.items(): + idx, wrd = key.split('_') + append_word2phns[str(int(idx) + length) + '_' + wrd] = value + new_word2phns = origin_word2phns.copy() + new_word2phns.update(append_word2phns) + + else: + if source_lang == target_lang and target_lang == "english": + new_phns, new_word2phns = words2phns(new_str) + elif source_lang == target_lang and target_lang == "chinese": + new_phns, new_word2phns = words2phns_zh(new_str) + else: + assert source_lang == target_lang, \ + "source language is not same with target language..." + + span_to_repl = [0, len(old_phns) - 1] + span_to_add = [0, len(new_phns) - 1] + left_idx = 0 + new_phns_left = [] + sp_count = 0 + # find the left different index + for key in word2phns.keys(): + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_left.append('sp') + else: + idx = str(int(idx) - sp_count) + if idx + '_' + wrd in new_word2phns: + left_idx += len(new_word2phns[idx + '_' + wrd]) + new_phns_left.extend(word2phns[key].split()) + else: + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + break + + # reverse word2phns and new_word2phns + right_idx = 0 + new_phns_right = [] + sp_count = 0 + word2phns_max_idx = get_max_idx(word2phns) + new_word2phns_max_idx = get_max_idx(new_word2phns) + new_phns_mid = [] + if is_append: + new_phns_right = [] + new_phns_mid = new_phns[left_idx:] + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + span_to_repl[1] = len(old_phns) - len(new_phns_right) + # speech edit + else: + for key in list(word2phns.keys())[::-1]: + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_right = ['sp'] + new_phns_right + else: + idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx) + - sp_count)) + if idx + '_' + wrd in new_word2phns: + right_idx -= len(new_word2phns[idx + '_' + wrd]) + new_phns_right = word2phns[key].split() + new_phns_right + else: + span_to_repl[1] = len(old_phns) - len(new_phns_right) + new_phns_mid = new_phns[left_idx:right_idx] + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + if len(new_phns_mid) == 0: + span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) + span_to_add[0] = max(0, span_to_add[0] - 1) + span_to_repl[0] = max(0, span_to_repl[0] - 1) + span_to_repl[1] = min(span_to_repl[1] + 1, + len(old_phns)) + break + new_phns = new_phns_left + new_phns_mid + new_phns_right + ''' + For that reason cover should not be given. + For that reason cover is impossible to be given. 
+ span_to_repl: [17, 23] "should not" + span_to_add: [17, 30] "is impossible to" + ''' + return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add + + +# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 +# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 +def get_dur_adj_factor(orig_dur: List[int], + pred_dur: List[int], + phns: List[str]): + length = 0 + factor_list = [] + for orig, pred, phn in zip(orig_dur, pred_dur, phns): + if pred == 0 or phn == 'sp': + continue + else: + factor_list.append(orig / pred) + factor_list = np.array(factor_list) + factor_list.sort() + if len(factor_list) < 5: + return 1 + length = 2 + avg = np.average(factor_list[length:-length]) + return avg + + +def prep_feats_with_dur(wav_path: str, + mlm_model: nn.Layer, + source_lang: str="English", + target_lang: str="English", + old_str: str="", + new_str: str="", + mask_reconstruct: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, + fs: int=24000, + hop_length: int=300): + ''' + Returns: + np.ndarray: new wav, replace the part to be edited in original wav with 0 + List[str]: new phones + List[float]: mfa start of new wav + List[float]: mfa end of new wav + List[int]: masked mel boundary of original wav + List[int]: masked mel boundary of new wav + ''' + wav_org, _ = librosa.load(wav_path, sr=fs) + + mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang) + + if start_end_sp: + if new_phns[-1] != 'sp': + new_phns = new_phns + ['sp'] + # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 + if target_lang == "english" or target_lang == "chinese": + old_durs = eval_durs(old_phns, target_lang=source_lang) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." + + orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] + if '[MASK]' in new_str: + new_phns = old_phns + span_to_add = span_to_repl + d_factor_left = get_dur_adj_factor( + orig_dur=orig_old_durs[:span_to_repl[0]], + pred_dur=old_durs[:span_to_repl[0]], + phns=old_phns[:span_to_repl[0]]) + d_factor_right = get_dur_adj_factor( + orig_dur=orig_old_durs[span_to_repl[1]:], + pred_dur=old_durs[span_to_repl[1]:], + phns=old_phns[span_to_repl[1]:]) + d_factor = (d_factor_left + d_factor_right) / 2 + new_durs_adjusted = [d_factor * i for i in old_durs] + else: + if duration_adjust: + d_factor = get_dur_adj_factor( + orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) + d_factor = d_factor * 1.25 + else: + d_factor = 1 + + if target_lang == "english" or target_lang == "chinese": + new_durs = eval_durs(new_phns, target_lang=target_lang) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." 
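+        # rescale the FastSpeech2-predicted durations so that their overall
+        # scale matches the MFA durations of the original utterance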
+ + new_durs_adjusted = [d_factor * i for i in new_durs] + + new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) + old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) + dur_offset = new_span_dur_sum - old_span_dur_sum + new_mfa_start = mfa_start[:span_to_repl[0]] + new_mfa_end = mfa_end[:span_to_repl[0]] + for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: + if len(new_mfa_end) == 0: + new_mfa_start.append(0) + new_mfa_end.append(i) + else: + new_mfa_start.append(new_mfa_end[-1]) + new_mfa_end.append(new_mfa_end[-1] + i) + new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] + new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] + + # 3. get new wav + # 在原始句子后拼接 + if span_to_repl[0] >= len(mfa_start): + left_idx = len(wav_org) + right_idx = left_idx + # 在原始句子中间替换 + else: + left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) + right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) + blank_wav = np.zeros( + (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype) + # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 + new_wav = np.concatenate( + [wav_org[:left_idx], blank_wav, wav_org[right_idx:]]) + + # 4. get old and new mel span to be mask + # [92, 92] + + old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy( + mfa_start=mfa_start, + mfa_end=mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_repl) + # [92, 174] + # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别 + new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy( + mfa_start=new_mfa_start, + mfa_end=new_mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_add) + + # old_span_bdy, new_span_bdy 是帧级别的范围 + return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy + + +def prep_feats(mlm_model: nn.Layer, + wav_path: str, + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_adjust: bool=True, + start_end_sp: bool=False, + mask_reconstruct: bool=False, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + old_str=old_str, + new_str=new_str, + wav_path=wav_path, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + mask_reconstruct=mask_reconstruct, + fs=fs, + hop_length=hop_length) + + token_to_id = {item: i for i, item in enumerate(token_list)} + text = np.array( + list(map(lambda x: token_to_id.get(x, token_to_id['']), phns))) + span_bdy = np.array(new_span_bdy) + + batch = [('1', { + "speech": wav, + "align_start": mfa_start, + "align_end": mfa_end, + "text": text, + "span_bdy": span_bdy + })] + + return batch, old_span_bdy, new_span_bdy + + +def decode_with_model(mlm_model: nn.Layer, + collate_fn, + wav_path: str, + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + batch, old_span_bdy, new_span_bdy = prep_feats( + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + fs=fs, + hop_length=hop_length, + token_list=token_list) + + feats = collate_fn(batch)[1] + + if 'text_masked_pos' in 
feats.keys(): + feats.pop('text_masked_pos') + + output = mlm_model.inference( + text=feats['text'], + speech=feats['speech'], + masked_pos=feats['masked_pos'], + speech_mask=feats['speech_mask'], + text_mask=feats['text_mask'], + speech_seg_pos=feats['speech_seg_pos'], + text_seg_pos=feats['text_seg_pos'], + span_bdy=new_span_bdy, + use_teacher_forcing=use_teacher_forcing) + + # 拼接音频 + output_feat = paddle.concat(x=output, axis=0) + wav_org, _ = librosa.load(wav_path, sr=fs) + return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length + + +def get_mlm_output(wav_path: str, + model_name: str="paddle_checkpoint_en", + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False): + mlm_model, train_conf = load_model(model_name) + mlm_model.eval() + + collate_fn = build_mlm_collate_fn( + sr=train_conf.feats_extract_conf['fs'], + n_fft=train_conf.feats_extract_conf['n_fft'], + hop_length=train_conf.feats_extract_conf['hop_length'], + win_length=train_conf.feats_extract_conf['win_length'], + n_mels=train_conf.feats_extract_conf['n_mels'], + fmin=train_conf.feats_extract_conf['fmin'], + fmax=train_conf.feats_extract_conf['fmax'], + mlm_prob=train_conf['mlm_prob'], + mean_phn_span=train_conf['mean_phn_span'], + seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm') + + return decode_with_model( + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + collate_fn=collate_fn, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + use_teacher_forcing=use_teacher_forcing, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + fs=train_conf.feats_extract_conf['fs'], + hop_length=train_conf.feats_extract_conf['hop_length'], + token_list=train_conf.token_list) + + +def evaluate(uid: str, + source_lang: str="english", + target_lang: str="english", + prefix: os.PathLike="./prompt/dev/", + model_name: str="paddle_checkpoint_en", + new_str: str="", + prompt_decoding: bool=False, + task_name: str=None): + + # get origin text and path of origin wav + old_str, wav_path = read_data(uid=uid, prefix=prefix) + + if task_name == 'edit': + new_str = new_str + elif task_name == 'synthesize': + new_str = old_str + new_str + else: + new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) + + print('new_str is ', new_str) + + results_dict = get_wav( + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + old_str=old_str, + new_str=new_str) + return results_dict + + +if __name__ == "__main__": + # parse config and args + args = parse_args() + + data_dict = evaluate( + uid=args.uid, + source_lang=args.source_lang, + target_lang=args.target_lang, + prefix=args.prefix, + model_name=args.model_name, + new_str=args.new_str, + task_name=args.task_name) + sf.write(args.output_name, data_dict['output'], samplerate=24000) + print("finished...") diff --git a/examples/ernie_sat/local/sedit_arg_parser.py b/examples/ernie_sat/local/sedit_arg_parser.py new file mode 100644 index 000000000..21c6d0b4b --- /dev/null +++ b/examples/ernie_sat/local/sedit_arg_parser.py @@ -0,0 +1,84 @@ +import argparse + + +def parse_args(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser( + description="Synthesize with acoustic model & vocoder") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', 
'fastspeech2_csmsc', 'fastspeech2_ljspeech',
+            'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
+            'tacotron2_ljspeech', 'tacotron2_aishell3'
+        ],
+        help='Choose acoustic model type of tts task.')
+    parser.add_argument(
+        '--am_config',
+        type=str,
+        default=None,
+        help='Config of acoustic model. Use default config when it is None.')
+    parser.add_argument(
+        '--am_ckpt',
+        type=str,
+        default=None,
+        help='Checkpoint file of acoustic model.')
+    parser.add_argument(
+        "--am_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
+    )
+    parser.add_argument(
+        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+    parser.add_argument(
+        "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker_dict", type=str, default=None, help="speaker id map file.")
+
+    # vocoder
+    parser.add_argument(
+        '--voc',
+        type=str,
+        default='pwgan_aishell3',
+        choices=[
+            'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
+            'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
+            'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
+            'style_melgan_csmsc'
+        ],
+        help='Choose vocoder type of tts task.')
+    parser.add_argument(
+        '--voc_config',
+        type=str,
+        default=None,
+        help='Config of voc. Use default config when it is None.')
+    parser.add_argument(
+        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
+    parser.add_argument(
+        "--voc_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training voc."
+    )
+    # other
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+
+    parser.add_argument("--model_name", type=str, help="model name")
+    parser.add_argument("--uid", type=str, help="uid")
+    parser.add_argument("--new_str", type=str, help="new string")
+    parser.add_argument("--prefix", type=str, help="prefix")
+    parser.add_argument(
+        "--source_lang", type=str, default="english", help="source language")
+    parser.add_argument(
+        "--target_lang", type=str, default="english", help="target language")
+    parser.add_argument("--output_name", type=str, help="output name")
+    parser.add_argument("--task_name", type=str, help="task name")
+
+    args = parser.parse_args()
+    return args
diff --git a/examples/ernie_sat/local/utils.py b/examples/ernie_sat/local/utils.py
new file mode 100644
index 000000000..836942a26
--- /dev/null
+++ b/examples/ernie_sat/local/utils.py
@@ -0,0 +1,162 @@
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Union
+
+import numpy as np
+import paddle
+import yaml
+from sedit_arg_parser import parse_args
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
+
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
+    """Read a text file having 2 columns as a dict object.
+
+    Examples:
+        wav.scp:
+            key1 /some/path/a.wav
+            key2 /some/path/b.wav
+
+        >>> read_2col_text('wav.scp')
+        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
+
+    """
+
+    data = {}
+    with Path(path).open("r", encoding="utf-8") as f:
+        for linenum, line in enumerate(f, 1):
+            sps = line.rstrip().split(maxsplit=1)
+            if len(sps) == 1:
+                k, v = sps[0], ""
+            else:
+                k, v = sps
+            if k in data:
+                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+            data[k] = v
+    return data
+
+
+def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
+                           ) -> Dict[str, List[Union[float, int]]]:
+    """Read a text file indicating sequences of numbers.
+
+    Examples:
+        key1 1 2 3
+        key2 34 5 6
+
+        >>> d = load_num_sequence_text('text')
+        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
+    """
+    if loader_type == "text_int":
+        delimiter = " "
+        dtype = int
+    elif loader_type == "text_float":
+        delimiter = " "
+        dtype = float
+    elif loader_type == "csv_int":
+        delimiter = ","
+        dtype = int
+    elif loader_type == "csv_float":
+        delimiter = ","
+        dtype = float
+    else:
+        raise ValueError(f"Not supported loader_type={loader_type}")
+
+    # path looks like:
+    #   utta 1,0
+    #   uttb 3,4,5
+    # -> return {'utta': np.ndarray([1, 0]),
+    #            'uttb': np.ndarray([3, 4, 5])}
+    d = read_2col_text(path)
+    # Using a for-loop instead of a dict-comprehension for debuggability
+    retval = {}
+    for k, v in d.items():
+        try:
+            retval[k] = [dtype(i) for i in v.split(delimiter)]
+        except (ValueError, TypeError):
+            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
+            raise
+    return retval
+
+
+def is_chinese(ch):
+    if u'\u4e00' <= ch <= u'\u9fff':
+        return True
+    else:
+        return False
+
+
+def get_voc_out(mel):
+    # vocoder
+    args = parse_args()
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
+
+    with paddle.no_grad():
+        wav = voc_inference(mel)
+    return np.squeeze(wav)
+
+
+def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
+    args = parse_args()
+
+    if target_lang == 'english':
+        args.am = "fastspeech2_ljspeech"
+        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
+        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
+        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
+
+    elif target_lang == 'chinese':
+        args.am = "fastspeech2_csmsc"
+        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
+        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0!")
+
+    # Init body.
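+    # build the FastSpeech2 acoustic model and reuse its duration predictor
+    # to get a predicted duration (in seconds) for each phone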
+ with open(args.am_config) as f: + am_config = CfgNode(yaml.safe_load(f)) + + am_inference, am = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict, + return_am=True) + + vocab_phones = {} + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + for tone, id in phn_id: + vocab_phones[tone] = int(id) + vocab_size = len(vocab_phones) + phonemes = [phn if phn in vocab_phones else "sp" for phn in phns] + + phone_ids = [vocab_phones[item] for item in phonemes] + phone_ids.append(vocab_size - 1) + phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64)) + _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None) + pre_d_outs = d_outs + phu_durs_new = pre_d_outs * hop_length / fs + phu_durs_new = phu_durs_new.tolist()[:-1] + return phu_durs_new diff --git a/examples/ernie_sat/path.sh b/examples/ernie_sat/path.sh new file mode 100755 index 000000000..d46d2f612 --- /dev/null +++ b/examples/ernie_sat/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=ernie_sat +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} \ No newline at end of file diff --git a/examples/ernie_sat/prompt/dev/text b/examples/ernie_sat/prompt/dev/text new file mode 100644 index 000000000..f79cdcb42 --- /dev/null +++ b/examples/ernie_sat/prompt/dev/text @@ -0,0 +1,3 @@ +p243_new For that reason cover should not be given. +Prompt_003_new This was not the show for me. +p299_096 We are trying to establish a date. diff --git a/examples/ernie_sat/prompt/dev/wav.scp b/examples/ernie_sat/prompt/dev/wav.scp new file mode 100644 index 000000000..eb0e8e48d --- /dev/null +++ b/examples/ernie_sat/prompt/dev/wav.scp @@ -0,0 +1,3 @@ +p243_new ../../prompt_wav/p243_313.wav +Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav +p299_096 ../../prompt_wav/p299_096.wav diff --git a/examples/ernie_sat/run_clone_en_to_zh.sh b/examples/ernie_sat/run_clone_en_to_zh.sh new file mode 100755 index 000000000..68b1c7544 --- /dev/null +++ b/examples/ernie_sat/run_clone_en_to_zh.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -e +source path.sh + +# en --> zh 的 语音合成 +# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好' +# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。 + +python local/inference.py \ + --task_name=cross-lingual_clone \ + --model_name=paddle_checkpoint_dual_mask_enzh \ + --uid=Prompt_003_new \ + --new_str='今天天气很好.' 
\ + --prefix='./prompt/dev/' \ + --source_lang=english \ + --target_lang=chinese \ + --output_name=pred_clone.wav \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_csmsc \ + --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \ + --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \ + --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt diff --git a/examples/ernie_sat/run_gen_en.sh b/examples/ernie_sat/run_gen_en.sh new file mode 100755 index 000000000..a0641bc7f --- /dev/null +++ b/examples/ernie_sat/run_gen_en.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -e +source path.sh + +# 纯英文的语音合成 +# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.' + +python local/inference.py \ + --task_name=synthesize \ + --model_name=paddle_checkpoint_en \ + --uid=p299_096 \ + --new_str='I enjoy my life, do you?' \ + --prefix='./prompt/dev/' \ + --source_lang=english \ + --target_lang=english \ + --output_name=pred_gen.wav \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_ljspeech \ + --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ + --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ + --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt diff --git a/examples/ernie_sat/run_sedit_en.sh b/examples/ernie_sat/run_sedit_en.sh new file mode 100755 index 000000000..eec7d6402 --- /dev/null +++ b/examples/ernie_sat/run_sedit_en.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -e +source path.sh + +# 纯英文的语音编辑 +# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音 +# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作 + +python local/inference.py \ + --task_name=edit \ + --model_name=paddle_checkpoint_en \ + --uid=p243_new \ + --new_str='for that reason cover is impossible to be given.' 
\ + --prefix='./prompt/dev/' \ + --source_lang=english \ + --target_lang=english \ + --output_name=pred_edit.wav \ + --voc=pwgan_aishell3 \ + --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --am=fastspeech2_ljspeech \ + --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \ + --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \ + --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \ + --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt diff --git a/examples/ernie_sat/test_run.sh b/examples/ernie_sat/test_run.sh new file mode 100755 index 000000000..75b6a5691 --- /dev/null +++ b/examples/ernie_sat/test_run.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +rm -rf *.wav +./run_sedit_en.sh # 语音编辑任务(英文) +./run_gen_en.sh # 个性化语音合成任务(英文) +./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) \ No newline at end of file diff --git a/examples/ernie_sat/tools/.gitkeep b/examples/ernie_sat/tools/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/examples/vctk/ernie_sat/README.md b/examples/vctk/ernie_sat/README.md new file mode 100644 index 000000000..055e7903d --- /dev/null +++ b/examples/vctk/ernie_sat/README.md @@ -0,0 +1 @@ +# ERNIE SAT with VCTK dataset diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 0b278abaf..1c70b1cdc 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -11,10 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Collection +from typing import Dict +from typing import List +from typing import Tuple + import numpy as np import paddle from paddlespeech.t2s.datasets.batch import batch_sequences +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.modules.nets_utils import get_seg_pos +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.nets_utils import pad_list +from paddlespeech.t2s.modules.nets_utils import phones_masking +from paddlespeech.t2s.modules.nets_utils import phones_text_masking def tacotron2_single_spk_batch_fn(examples): @@ -335,3 +346,182 @@ def vits_single_spk_batch_fn(examples): "speech": speech } return batch + + +# for ERNIE SAT +class MLMCollateFn: + """Functor class of common_collate_fn()""" + + def __init__( + self, + feats_extract, + mlm_prob: float=0.8, + mean_phn_span: int=8, + seg_emb: bool=False, + text_masking: bool=False, + attention_window: int=0, + not_sequence: Collection[str]=(), ): + self.mlm_prob = mlm_prob + self.mean_phn_span = mean_phn_span + self.feats_extract = feats_extract + self.not_sequence = set(not_sequence) + self.attention_window = attention_window + self.seg_emb = seg_emb + self.text_masking = text_masking + + def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] + ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + return mlm_collate_fn( + data, + feats_extract=self.feats_extract, + mlm_prob=self.mlm_prob, + mean_phn_span=self.mean_phn_span, + seg_emb=self.seg_emb, + text_masking=self.text_masking, + attention_window=self.attention_window, + not_sequence=self.not_sequence) + + +def mlm_collate_fn( + data: Collection[Tuple[str, Dict[str, np.ndarray]]], + feats_extract=None, + mlm_prob: float=0.8, + mean_phn_span: int=8, + seg_emb: bool=False, + text_masking: bool=False, + attention_window: int=0, + pad_value: int=0, + not_sequence: Collection[str]=(), +) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + uttids = [u for u, _ in data] + data = [d for _, d in data] + + assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" + assert all(not k.endswith("_lens") + for k in data[0]), f"*_lens is reserved: {list(data[0])}" + + output = {} + for key in data[0]: + + array_list = [d[key] for d in data] + + # Assume the first axis is length: + # tensor_list: Batch x (Length, ...) + tensor_list = [paddle.to_tensor(a) for a in array_list] + # tensor: (Batch, Length, ...) 
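+        # pad_list pads every sequence in the batch to the longest one with pad_value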
+ tensor = pad_list(tensor_list, pad_value) + output[key] = tensor + + # lens: (Batch,) + if key not in not_sequence: + lens = paddle.to_tensor( + [d[key].shape[0] for d in data], dtype=paddle.int64) + output[key + "_lens"] = lens + + feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) + feats = paddle.to_tensor(feats) + feats_lens = paddle.shape(feats)[0] + feats = paddle.unsqueeze(feats, 0) + + text = output["text"] + text_lens = output["text_lens"] + align_start = output["align_start"] + align_start_lens = output["align_start_lens"] + align_end = output["align_end"] + + max_tlen = max(text_lens) + max_slen = max(feats_lens) + + speech_pad = feats[:, :max_slen] + + text_pad = text + text_mask = make_non_pad_mask( + text_lens, text_pad, length_dim=1).unsqueeze(-2) + speech_mask = make_non_pad_mask( + feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2) + span_bdy = None + if 'span_bdy' in output.keys(): + span_bdy = output['span_bdy'] + + # dual_mask 的是混合中英时候同时 mask 语音和文本 + # ernie sat 在实现跨语言的时候都 mask 了 + if text_masking: + masked_pos, text_masked_pos = phones_text_masking( + xs_pad=speech_pad, + src_mask=speech_mask, + text_pad=text_pad, + text_mask=text_mask, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span, + span_bdy=span_bdy) + # 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了 + # a3t 和 ernie sat 的区别主要在于做 mask 的时候 + else: + masked_pos = phones_masking( + xs_pad=speech_pad, + src_mask=speech_mask, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span, + span_bdy=span_bdy) + text_masked_pos = paddle.zeros(paddle.shape(text_pad)) + + output_dict = {} + + speech_seg_pos, text_seg_pos = get_seg_pos( + speech_pad=speech_pad, + text_pad=text_pad, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + seg_emb=seg_emb) + output_dict['speech'] = speech_pad + output_dict['text'] = text_pad + output_dict['masked_pos'] = masked_pos + output_dict['text_masked_pos'] = text_masked_pos + output_dict['speech_mask'] = speech_mask + output_dict['text_mask'] = text_mask + output_dict['speech_seg_pos'] = speech_seg_pos + output_dict['text_seg_pos'] = text_seg_pos + output = (uttids, output_dict) + return output + + +def build_mlm_collate_fn( + sr: int=24000, + n_fft: int=2048, + hop_length: int=300, + win_length: int=None, + n_mels: int=80, + fmin: int=80, + fmax: int=7600, + mlm_prob: float=0.8, + mean_phn_span: int=8, + seg_emb: bool=False, + epoch: int=-1, ): + feats_extract_class = LogMelFBank + + feats_extract = feats_extract_class( + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mels=n_mels, + fmin=fmin, + fmax=fmax) + + if epoch == -1: + mlm_prob_factor = 1 + else: + mlm_prob_factor = 0.8 + + return MLMCollateFn( + feats_extract=feats_extract, + mlm_prob=mlm_prob * mlm_prob_factor, + mean_phn_span=mean_phn_span, + seg_emb=seg_emb) diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 6b9f41a6b..cabea9897 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -147,14 +147,14 @@ def get_frontend(lang: str='zh', # dygraph -def get_am_inference( - am: str='fastspeech2_csmsc', - am_config: CfgNode=None, - am_ckpt: Optional[os.PathLike]=None, - am_stat: Optional[os.PathLike]=None, - phones_dict: Optional[os.PathLike]=None, - tones_dict: Optional[os.PathLike]=None, - 
speaker_dict: Optional[os.PathLike]=None, ): +def get_am_inference(am: str='fastspeech2_csmsc', + am_config: CfgNode=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + return_am: bool=False): with open(phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) @@ -203,7 +203,10 @@ def get_am_inference( am_inference = am_inference_class(am_normalizer, am) am_inference.eval() print("acoustic model done!") - return am_inference + if return_am: + return am_inference, am + else: + return am_inference def get_voc_inference( diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py index 0b6f29119..d8df4368a 100644 --- a/paddlespeech/t2s/models/__init__.py +++ b/paddlespeech/t2s/models/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .ernie_sat import * from .fastspeech2 import * from .hifigan import * from .melgan import * diff --git a/paddlespeech/t2s/models/ernie_sat/__init__.py b/paddlespeech/t2s/models/ernie_sat/__init__.py new file mode 100644 index 000000000..dc86fa514 --- /dev/null +++ b/paddlespeech/t2s/models/ernie_sat/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .mlm import * diff --git a/paddlespeech/t2s/models/ernie_sat/mlm.py b/paddlespeech/t2s/models/ernie_sat/mlm.py new file mode 100644 index 000000000..c9c3d67a6 --- /dev/null +++ b/paddlespeech/t2s/models/ernie_sat/mlm.py @@ -0,0 +1,601 @@ +import argparse +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import paddle +import yaml +from paddle import nn +from yacs.config import CfgNode + +from paddlespeech.t2s.modules.activation import get_activation +from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule +from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer +from paddlespeech.t2s.modules.layer_norm import LayerNorm +from paddlespeech.t2s.modules.masked_fill import masked_fill +from paddlespeech.t2s.modules.nets_utils import initialize +from paddlespeech.t2s.modules.tacotron2.decoder import Postnet +from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention +from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention +from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention +from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding +from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear +from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d +from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward +from paddlespeech.t2s.modules.transformer.repeat import repeat +from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling + + +# MLM -> Mask Language Model +class mySequential(nn.Sequential): + def forward(self, *inputs): + for module in self._sub_layers.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class MaskInputLayer(nn.Layer): + def __init__(self, out_features: int) -> None: + super().__init__() + self.mask_feature = paddle.create_parameter( + shape=(1, 1, out_features), + dtype=paddle.float32, + default_initializer=paddle.nn.initializer.Assign( + paddle.normal(shape=(1, 1, out_features)))) + + def forward(self, input: paddle.Tensor, + masked_pos: paddle.Tensor=None) -> paddle.Tensor: + masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input) + masked_input = masked_fill(input, masked_pos, 0) + masked_fill( + paddle.expand_as(self.mask_feature, input), ~masked_pos, 0) + return masked_input + + +class MLMEncoder(nn.Layer): + """Conformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, paddle.nn.Layer]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. 
+ concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + pos_enc_layer_type (str): Encoder positional encoding layer type. + selfattention_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + stochastic_depth_rate (float): Maximum probability to skip the encoder layer. + + """ + + def __init__(self, + idim: int, + vocab_size: int=0, + pre_speech_layer: int=0, + attention_dim: int=256, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + attention_dropout_rate: float=0.0, + input_layer: str="conv2d", + normalize_before: bool=True, + concat_after: bool=False, + positionwise_layer_type: str="linear", + positionwise_conv_kernel_size: int=1, + macaron_style: bool=False, + pos_enc_layer_type: str="abs_pos", + pos_enc_class=None, + selfattention_layer_type: str="selfattn", + activation_type: str="swish", + use_cnn_module: bool=False, + zero_triu: bool=False, + cnn_module_kernel: int=31, + padding_idx: int=-1, + stochastic_depth_rate: float=0.0, + text_masking: bool=False): + """Construct an Encoder object.""" + super().__init__() + self._output_size = attention_dim + self.text_masking = text_masking + if self.text_masking: + self.text_masking_layer = MaskInputLayer(attention_dim) + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + pos_enc_class = LegacyRelPositionalEncoding + assert selfattention_layer_type == "legacy_rel_selfattn" + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = nn.Sequential( + nn.Linear(idim, attention_dim), + nn.LayerNorm(attention_dim), + nn.Dropout(dropout_rate), + nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), ) + self.conv_subsampling_factor = 4 + elif input_layer == "embed": + self.embed = nn.Sequential( + nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif input_layer == "mlm": + self.segment_emb = None + self.speech_embed = mySequential( + MaskInputLayer(idim), + nn.Linear(idim, attention_dim), + nn.LayerNorm(attention_dim), + nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate)) + self.text_embed = nn.Sequential( + nn.Embedding( + vocab_size, attention_dim, 
padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif input_layer == "sega_mlm": + self.segment_emb = nn.Embedding( + 500, attention_dim, padding_idx=padding_idx) + self.speech_embed = mySequential( + MaskInputLayer(idim), + nn.Linear(idim, attention_dim), + nn.LayerNorm(attention_dim), + nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate)) + self.text_embed = nn.Sequential( + nn.Embedding( + vocab_size, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif isinstance(input_layer, nn.Layer): + self.embed = nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), ) + elif input_layer is None: + self.embed = nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate)) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + + # self-attention module definition + if selfattention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, attention_dim, + attention_dropout_rate, zero_triu, ) + else: + raise ValueError("unknown encoder_attn_layer: " + + selfattention_layer_type) + + # feed-forward module definition + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = (attention_dim, linear_units, + dropout_rate, activation, ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = (attention_dim, linear_units, + positionwise_conv_kernel_size, + dropout_rate, ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = (attention_dim, linear_units, + positionwise_conv_kernel_size, + dropout_rate, ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate * float(1 + lnum) / num_blocks, ), ) + self.pre_speech_layer = pre_speech_layer + self.pre_speech_encoders = repeat( + self.pre_speech_layer, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate * float(1 + lnum) / 
self.pre_speech_layer, ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor=None, + text_mask: paddle.Tensor=None, + speech_seg_pos: paddle.Tensor=None, + text_seg_pos: paddle.Tensor=None): + """Encode input sequence. + + """ + if masked_pos is not None: + speech = self.speech_embed(speech, masked_pos) + else: + speech = self.speech_embed(speech) + if text is not None: + text = self.text_embed(text) + if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb: + speech_seg_emb = self.segment_emb(speech_seg_pos) + text_seg_emb = self.segment_emb(text_seg_pos) + text = (text[0] + text_seg_emb, text[1]) + speech = (speech[0] + speech_seg_emb, speech[1]) + if self.pre_speech_encoders: + speech, _ = self.pre_speech_encoders(speech, speech_mask) + + if text is not None: + xs = paddle.concat([speech[0], text[0]], axis=1) + xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1) + masks = paddle.concat([speech_mask, text_mask], axis=-1) + else: + xs = speech[0] + xs_pos_emb = speech[1] + masks = speech_mask + + xs, masks = self.encoders((xs, xs_pos_emb), masks) + + if isinstance(xs, tuple): + xs = xs[0] + if self.normalize_before: + xs = self.after_norm(xs) + + return xs, masks + + +class MLMDecoder(MLMEncoder): + def forward(self, xs: paddle.Tensor, masks: paddle.Tensor): + """Encode input sequence. + + Args: + xs (paddle.Tensor): Input tensor (#batch, time, idim). + masks (paddle.Tensor): Mask tensor (#batch, time). + + Returns: + paddle.Tensor: Output tensor (#batch, time, attention_dim). + paddle.Tensor: Mask tensor (#batch, time). + + """ + xs = self.embed(xs) + xs, masks = self.encoders(xs, masks) + + if isinstance(xs, tuple): + xs = xs[0] + if self.normalize_before: + xs = self.after_norm(xs) + + return xs, masks + + +# encoder and decoder is nn.Layer, not str +class MLM(nn.Layer): + def __init__(self, + token_list: Union[Tuple[str, ...], List[str]], + odim: int, + encoder: nn.Layer, + decoder: Optional[nn.Layer], + postnet_layers: int=0, + postnet_chans: int=0, + postnet_filts: int=0, + text_masking: bool=False): + + super().__init__() + self.odim = odim + self.token_list = token_list.copy() + self.encoder = encoder + self.decoder = decoder + self.vocab_size = encoder.text_embed[0]._num_embeddings + + if self.decoder is None or not (hasattr(self.decoder, + 'output_layer') and + self.decoder.output_layer is not None): + self.sfc = nn.Linear(self.encoder._output_size, odim) + else: + self.sfc = None + if text_masking: + self.text_sfc = nn.Linear( + self.encoder.text_embed[0]._embedding_dim, + self.vocab_size, + weight_attr=self.encoder.text_embed[0]._weight_attr) + else: + self.text_sfc = None + + self.postnet = (None if postnet_layers == 0 else Postnet( + idim=self.encoder._output_size, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=True, + dropout_rate=0.5, )) + + def inference( + self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor, + span_bdy: List[int], + use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: + ''' + Args: + speech (paddle.Tensor): input speech (1, Tmax, D). + text (paddle.Tensor): input text (1, Tmax2). 
+ masked_pos (paddle.Tensor): masked position of input speech (1, Tmax) + speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax). + text_mask (paddle.Tensor): mask of text (1, 1, Tmax2). + speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax). + text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2). + span_bdy (List[int]): masked mel boundary of input speech (2,) + use_teacher_forcing (bool): whether to use teacher forcing + Returns: + List[Tensor]: + eg: + [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])] + ''' + + z_cache = None + if use_teacher_forcing: + before_outs, zs, *_ = self.forward( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) + if zs is None: + zs = before_outs + + speech = speech.squeeze(0) + outs = [speech[:span_bdy[0]]] + outs += [zs[0][span_bdy[0]:span_bdy[1]]] + outs += [speech[span_bdy[1]:]] + return outs + return None + + +class MLMEncAsDecoder(MLM): + def forward(self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor): + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + encoder_out, h_masks = self.encoder( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) + if self.decoder is not None: + zs, _ = self.decoder(encoder_out, h_masks) + else: + zs = encoder_out + speech_hidden_states = zs[:, :paddle.shape(speech)[1], :] + if self.sfc is not None: + before_outs = paddle.reshape( + self.sfc(speech_hidden_states), + (paddle.shape(speech_hidden_states)[0], -1, self.odim)) + else: + before_outs = speech_hidden_states + if self.postnet is not None: + after_outs = before_outs + paddle.transpose( + self.postnet(paddle.transpose(before_outs, [0, 2, 1])), + [0, 2, 1]) + else: + after_outs = None + return before_outs, after_outs, None + + +class MLMDualMaksing(MLM): + def forward(self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor): + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + encoder_out, h_masks = self.encoder( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) + if self.decoder is not None: + zs, _ = self.decoder(encoder_out, h_masks) + else: + zs = encoder_out + speech_hidden_states = zs[:, :paddle.shape(speech)[1], :] + if self.text_sfc: + text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :] + text_outs = paddle.reshape( + self.text_sfc(text_hiddent_states), + (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size)) + if self.sfc is not None: + before_outs = paddle.reshape( + self.sfc(speech_hidden_states), + (paddle.shape(speech_hidden_states)[0], -1, self.odim)) + else: + before_outs = speech_hidden_states + if self.postnet is not None: + after_outs = before_outs + paddle.transpose( + self.postnet(paddle.transpose(before_outs, [0, 2, 1])), + [0, 2, 1]) + else: + after_outs = None + return before_outs, after_outs, text_outs + + +def build_model_from_file(config_file, model_file): + + 
state_dict = paddle.load(model_file)
+    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
+        else MLMEncAsDecoder
+
+    # build the model
+    with open(config_file) as f:
+        conf = CfgNode(yaml.safe_load(f))
+    model = build_model(conf, model_class)
+    model.set_state_dict(state_dict)
+    return model, conf
+
+
+# select encoder and decoder here
+def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
+    if isinstance(args.token_list, str):
+        with open(args.token_list, encoding="utf-8") as f:
+            token_list = [line.rstrip() for line in f]
+
+        # Overwriting token_list to keep it as "portable".
+        args.token_list = list(token_list)
+    elif isinstance(args.token_list, (tuple, list)):
+        token_list = list(args.token_list)
+    else:
+        raise RuntimeError("token_list must be str or list")
+
+    vocab_size = len(token_list)
+    odim = 80
+
+    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
+
+    if "conformer" == args.encoder:
+        conformer_self_attn_layer_type = args.encoder_conf[
+            'selfattention_layer_type']
+        conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type']
+        conformer_rel_pos_type = "legacy"
+        if conformer_rel_pos_type == "legacy":
+            if conformer_pos_enc_layer_type == "rel_pos":
+                conformer_pos_enc_layer_type = "legacy_rel_pos"
+            if conformer_self_attn_layer_type == "rel_selfattn":
+                conformer_self_attn_layer_type = "legacy_rel_selfattn"
+        elif conformer_rel_pos_type == "latest":
+            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
+            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
+        else:
+            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")
+        args.encoder_conf[
+            'selfattention_layer_type'] = conformer_self_attn_layer_type
+        args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type
+        if "conformer" == args.decoder:
+            args.decoder_conf[
+                'selfattention_layer_type'] = conformer_self_attn_layer_type
+            args.decoder_conf[
+                'pos_enc_layer_type'] = conformer_pos_enc_layer_type
+
+    # Encoder
+    encoder_class = MLMEncoder
+
+    if 'text_masking' in args.model_conf.keys() and args.model_conf[
+            'text_masking']:
+        args.encoder_conf['text_masking'] = True
+    else:
+        args.encoder_conf['text_masking'] = False
+
+    encoder = encoder_class(
+        args.input_size,
+        vocab_size=vocab_size,
+        pos_enc_class=pos_enc_class,
+        **args.encoder_conf)
+
+    # Decoder
+    if args.decoder != 'no_decoder':
+        decoder_class = MLMDecoder
+        decoder = decoder_class(
+            idim=0,
+            input_layer=None,
+            **args.decoder_conf, )
+    else:
+        decoder = None
+
+    # Build model
+    model = model_class(
+        odim=odim,
+        encoder=encoder,
+        decoder=decoder,
+        token_list=token_list,
+        **args.model_conf, )
+
+    # Initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index e6ab93513..4726f40ec 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -1007,3 +1007,55 @@ class KLDivergenceLoss(nn.Layer):
         loss = kl / paddle.sum(z_mask)
 
         return loss
+
+
+# loss for ERNIE SAT
+class MLMLoss(nn.Layer):
+    def __init__(self,
+                 lsm_weight: float=0.1,
+                 ignore_id: int=-1,
+                 text_masking: bool=False):
+        super().__init__()
+        if text_masking:
+            # reduction='none' keeps a per-token loss so it can be
+            # weighted by the text mask in forward()
+            self.text_mlm_loss = nn.CrossEntropyLoss(
+                ignore_index=ignore_id, reduction='none')
+        if lsm_weight > 50:
+            self.l1_loss_func = nn.MSELoss()
+        else:
+            self.l1_loss_func = nn.L1Loss(reduction='none')
+        self.text_masking = text_masking
+
+    def forward(self,
+                speech: paddle.Tensor,
+                before_outs: paddle.Tensor,
+                after_outs: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                text: paddle.Tensor=None,
+                text_outs: paddle.Tensor=None,
+                text_masked_pos: paddle.Tensor=None):
+
+        xs_pad = speech
+        # the feature and vocabulary sizes are not passed to __init__,
+        # so infer them from the inputs themselves
+        odim = paddle.shape(speech)[-1]
+        mlm_loss_pos = paddle.cast(masked_pos > 0, 'float32')
+        loss = paddle.sum(
+            self.l1_loss_func(
+                paddle.reshape(before_outs, (-1, odim)),
+                paddle.reshape(xs_pad, (-1, odim))),
+            axis=-1)
+        if after_outs is not None:
+            loss += paddle.sum(
+                self.l1_loss_func(
+                    paddle.reshape(after_outs, (-1, odim)),
+                    paddle.reshape(xs_pad, (-1, odim))),
+                axis=-1)
+        loss_mlm = paddle.sum((loss * paddle.reshape(
+            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
+
+        if self.text_masking:
+            vocab_size = paddle.shape(text_outs)[-1]
+            text_mlm_pos = paddle.cast(text_masked_pos, 'float32')
+            loss_text = paddle.sum((self.text_mlm_loss(
+                paddle.reshape(text_outs, (-1, vocab_size)),
+                paddle.reshape(text, (-1))) * paddle.reshape(
+                    text_mlm_pos,
+                    (-1)))) / paddle.sum((text_mlm_pos) + 1e-10)
+
+            return loss_mlm, loss_text
+
+        return loss_mlm
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index 598b63164..0238f4dba 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from espnet(https://github.com/espnet/espnet)
+import math
 from typing import Tuple
 
+import numpy as np
 import paddle
 from paddle import nn
 from typeguard import check_argument_types
@@ -40,7 +42,8 @@ def pad_list(xs, pad_value):
     """
     n_batch = len(xs)
     max_len = max(x.shape[0] for x in xs)
-    pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value)
+    pad = paddle.full(
+        [n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype)
 
     for i in range(n_batch):
         pad[i, :xs[i].shape[0]] = xs[i]
@@ -48,13 +51,17 @@ def pad_list(xs, pad_value):
     return pad
 
 
-def make_pad_mask(lengths, length_dim=-1):
+def make_pad_mask(lengths, xs=None, length_dim=-1):
     """Make mask tensor containing indices of padded part.
 
     Args:
         lengths (Tensor(int64)): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
 
-    Returns: 
+    Returns:
         Tensor(bool): Mask tensor containing indices of padded part bool.
 
     Examples:
@@ -63,23 +70,99 @@
         >>> lengths = [5, 3, 2]
         >>> make_non_pad_mask(lengths)
         masks = [[0, 0, 0, 0 ,0],
-                 [0, 0, 0, 1, 1],
-                 [0, 0, 1, 1, 1]]
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+
+        With the reference tensor.
+
+        >>> xs = paddle.zeros((3, 2, 4))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0],
+                 [0, 0, 0, 0]],
+                [[0, 0, 0, 1],
+                 [0, 0, 0, 1]],
+                [[0, 0, 1, 1],
+                 [0, 0, 1, 1]]])
+        >>> xs = paddle.zeros((3, 2, 6))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]])
+
+        With the reference tensor and dimension indicator.
+ + >>> xs = paddle.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]]) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]],) + """ if length_dim == 0: raise ValueError("length_dim cannot be 0: {}".format(length_dim)) bs = paddle.shape(lengths)[0] - maxlen = lengths.max() + if xs is None: + maxlen = lengths.max() + else: + maxlen = paddle.shape(xs)[length_dim] + seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen]) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand + if xs is not None: + assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs) + + if length_dim < 0: + length_dim = len(paddle.shape(xs)) + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None + for i in range(len(paddle.shape(xs)))) + mask = paddle.expand(mask[ind], paddle.shape(xs)) return mask -def make_non_pad_mask(lengths, length_dim=-1): +def make_non_pad_mask(lengths, xs=None, length_dim=-1): """Make mask tensor containing indices of non-padded part. Args: @@ -92,16 +175,78 @@ def make_non_pad_mask(lengths, length_dim=-1): Returns: Tensor(bool): mask tensor containing indices of padded part bool. - Examples: + Examples: With only lengths. >>> lengths = [5, 3, 2] >>> make_non_pad_mask(lengths) masks = [[1, 1, 1, 1 ,1], - [1, 1, 1, 0, 0], - [1, 1, 0, 0, 0]] + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = paddle.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]]) + >>> xs = paddle.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]]) + + With the reference tensor and dimension indicator. 
+ + >>> xs = paddle.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]]) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]]) + """ - return paddle.logical_not(make_pad_mask(lengths, length_dim)) + return paddle.logical_not(make_pad_mask(lengths, xs, length_dim)) def initialize(model: nn.Layer, init: str): @@ -194,3 +339,270 @@ def paddle_gather(x, dim, index): ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64") paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape) return paddle_out + + +# for ERNIE SAT +# mask phones +def phones_masking(xs_pad: paddle.Tensor, + src_mask: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + mlm_prob: float=0.8, + mean_phn_span: int=8, + span_bdy: paddle.Tensor=None): + ''' + Args: + xs_pad (paddle.Tensor): input speech (B, Tmax, D). + src_mask (paddle.Tensor): mask of speech (B, 1, Tmax). + align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2). + align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2). + align_start_lens (paddle.Tensor): length of align_start (B, ). + mlm_prob (float): + mean_phn_span (int): + span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2). + Returns: + paddle.Tensor[bool]: masked position of input speech (B, Tmax). 
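+    Example:
+        An illustrative sketch with toy shapes; which phones end up masked
+        depends on the random span sampling inside:
+        >>> xs_pad = paddle.zeros((1, 10, 80))          # 10 mel frames
+        >>> src_mask = paddle.ones((1, 1, 10), dtype='bool')
+        >>> align_start = paddle.to_tensor([[0, 3, 6]])
+        >>> align_end = paddle.to_tensor([[3, 6, 10]])
+        >>> align_start_lens = paddle.to_tensor([3])
+        >>> phones_masking(xs_pad, src_mask, align_start, align_end,
+        ...                align_start_lens, mlm_prob=0.8, mean_phn_span=2)
+        # -> bool tensor of shape (1, 10) that is True on whole-phone spans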
+ ''' + bz, sent_len, _ = paddle.shape(xs_pad) + masked_pos = paddle.zeros((bz, sent_len)) + if mlm_prob == 1.0: + masked_pos += 1 + elif mean_phn_span == 0: + # only speech + length = sent_len + mean_phn_span = min(length * mlm_prob // 3, 50) + masked_phn_idxs = random_spans_noise_mask( + length=length, mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() + masked_pos[:, masked_phn_idxs] = 1 + else: + for idx in range(bz): + # for inference + if span_bdy is not None: + for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): + masked_pos[idx, s:e] = 1 + # for training + else: + length = align_start_lens[idx] + if length < 2: + continue + masked_phn_idxs = random_spans_noise_mask( + length=length, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() + masked_start = align_start[idx][masked_phn_idxs].tolist() + masked_end = align_end[idx][masked_phn_idxs].tolist() + + for s, e in zip(masked_start, masked_end): + masked_pos[idx, s:e] = 1 + non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) + masked_pos = masked_pos * non_eos_mask + masked_pos = paddle.cast(masked_pos, 'bool') + + return masked_pos + + +# mask speech and phones +def phones_text_masking(xs_pad: paddle.Tensor, + src_mask: paddle.Tensor, + text_pad: paddle.Tensor, + text_mask: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + mlm_prob: float=0.8, + mean_phn_span: int=8, + span_bdy: paddle.Tensor=None): + ''' + Args: + xs_pad (paddle.Tensor): input speech (B, Tmax, D). + src_mask (paddle.Tensor): mask of speech (B, 1, Tmax). + text_pad (paddle.Tensor): input text (B, Tmax2). + text_mask (paddle.Tensor): mask of text (B, 1, Tmax2). + align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2). + align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2). + align_start_lens (paddle.Tensor): length of align_start (B, ). + mlm_prob (float): + mean_phn_span (int): + span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2). + Returns: + paddle.Tensor[bool]: masked position of input speech (B, Tmax). + paddle.Tensor[bool]: masked position of input text (B, Tmax2). 
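+    Example:
+        An illustrative sketch with toy shapes; the span choice is random:
+        >>> xs_pad = paddle.zeros((1, 10, 80))
+        >>> src_mask = paddle.ones((1, 1, 10), dtype='bool')
+        >>> text_pad = paddle.to_tensor([[5, 9, 7]])
+        >>> text_mask = paddle.ones((1, 1, 3), dtype='bool')
+        >>> align_start = paddle.to_tensor([[0, 3, 6]])
+        >>> align_end = paddle.to_tensor([[3, 6, 10]])
+        >>> align_start_lens = paddle.to_tensor([3])
+        >>> masked_pos, text_masked_pos = phones_text_masking(
+        ...     xs_pad, src_mask, text_pad, text_mask, align_start,
+        ...     align_end, align_start_lens, mlm_prob=0.8, mean_phn_span=2)
+        # masked_pos: (1, 10) mel-frame mask; text_masked_pos: (1, 3) phone
+        # mask over (some of) the phones left unmasked on the speech side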
+    '''
+    bz, sent_len, _ = paddle.shape(xs_pad)
+    masked_pos = paddle.zeros((bz, sent_len))
+    _, text_len = paddle.shape(text_pad)
+    text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
+    text_masked_pos = paddle.zeros((bz, text_len))
+
+    if mlm_prob == 1.0:
+        masked_pos += 1
+    elif mean_phn_span == 0:
+        # only speech
+        length = sent_len
+        mean_phn_span = min(length * mlm_prob // 3, 50)
+        masked_phn_idxs = random_spans_noise_mask(
+            length=length, mlm_prob=mlm_prob,
+            mean_phn_span=mean_phn_span).nonzero()
+        masked_pos[:, masked_phn_idxs] = 1
+    else:
+        for idx in range(bz):
+            # for inference
+            if span_bdy is not None:
+                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
+                    masked_pos[idx, s:e] = 1
+            # for training
+            else:
+                length = align_start_lens[idx]
+                if length < 2:
+                    continue
+                masked_phn_idxs = random_spans_noise_mask(
+                    length=length,
+                    mlm_prob=mlm_prob,
+                    mean_phn_span=mean_phn_span).nonzero()
+                unmasked_phn_idxs = list(
+                    set(range(length)) - set(masked_phn_idxs[0].tolist()))
+                np.random.shuffle(unmasked_phn_idxs)
+                masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
+                text_masked_pos[idx][masked_text_idxs] = 1
+                masked_start = align_start[idx][masked_phn_idxs].tolist()
+                masked_end = align_end[idx][masked_phn_idxs].tolist()
+                for s, e in zip(masked_start, masked_end):
+                    masked_pos[idx, s:e] = 1
+    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
+    masked_pos = masked_pos * non_eos_mask
+    # text_mask has shape (B, 1, Tmax2), so reduce it against text_pad,
+    # not xs_pad, to match text_masked_pos of shape (B, Tmax2)
+    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(text_pad)[:2])
+    text_masked_pos = text_masked_pos * non_eos_text_mask
+    masked_pos = paddle.cast(masked_pos, 'bool')
+    text_masked_pos = paddle.cast(text_masked_pos, 'bool')
+
+    return masked_pos, text_masked_pos
+
+
+def get_seg_pos(speech_pad: paddle.Tensor,
+                text_pad: paddle.Tensor,
+                align_start: paddle.Tensor,
+                align_end: paddle.Tensor,
+                align_start_lens: paddle.Tensor,
+                seg_emb: bool=False):
+    '''
+    Args:
+        speech_pad (paddle.Tensor): input speech (B, Tmax, D).
+        text_pad (paddle.Tensor): input text (B, Tmax2).
+        align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
+        align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
+        align_start_lens (paddle.Tensor): length of align_start (B, ).
+        seg_emb (bool): whether to use segment embedding.
+    Returns:
+        paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
+ eg: + Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 , + 5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , + 7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10, + 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, + 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, + 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, + 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, + 20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23, + 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, + 25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29, + 29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32, + 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35, + 35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, + 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, + 38, 38, 0 , 0 ]]) + paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2). + eg: + Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38]]) + ''' + + bz, speech_len, _ = paddle.shape(speech_pad) + _, text_len = paddle.shape(text_pad) + + text_seg_pos = paddle.zeros((bz, text_len), dtype='int64') + speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64') + + if not seg_emb: + return speech_seg_pos, text_seg_pos + for idx in range(bz): + align_length = align_start_lens[idx] + for j in range(align_length): + s, e = align_start[idx][j], align_end[idx][j] + speech_seg_pos[idx, s:e] = j + 1 + text_seg_pos[idx, j] = j + 1 + + return speech_seg_pos, text_seg_pos + + +# randomly select the range of speech and text to mask during training +def random_spans_noise_mask(length: int, + mlm_prob: float=0.8, + mean_phn_span: float=8): + """This function is copy of `random_spans_helper + `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. + Args: + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number + Returns: + np.ndarray: a boolean tensor with shape [length] + """ + + orig_length = length + + num_noise_tokens = int(np.round(length * mlm_prob)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. 
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / mean_phn_span)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_seg(num_items, num_segs): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segs: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segs] containing positive integers that add + up to num_items + """ + mask_idxs = np.arange(num_items - 1) < (num_segs - 1) + np.random.shuffle(mask_idxs) + first_in_seg = np.pad(mask_idxs, [[1, 0]]) + segment_id = np.cumsum(first_in_seg) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lens = _random_seg(num_noise_tokens, num_noise_spans) + nonnoise_span_lens = _random_seg(num_nonnoise_tokens, num_noise_spans) + + interleaved_span_lens = np.reshape( + np.stack([nonnoise_span_lens, noise_span_lens], axis=1), + [num_noise_spans * 2]) + span_starts = np.cumsum(interleaved_span_lens)[:-1] + span_start_indicator = np.zeros((length, ), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] diff --git a/paddlespeech/t2s/modules/transformer/attention.py b/paddlespeech/t2s/modules/transformer/attention.py index cdb95b211..538a36b6b 100644 --- a/paddlespeech/t2s/modules/transformer/attention.py +++ b/paddlespeech/t2s/modules/transformer/attention.py @@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask) + + +class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (old version). + Details can be found in https://github.com/espnet/espnet/pull/2816. + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + + self.pos_bias_u = paddle.create_parameter( + shape=(self.h, self.d_k), + dtype='float32', + default_initializer=paddle.nn.initializer.XavierUniform()) + self.pos_bias_v = paddle.create_parameter( + shape=(self.h, self.d_k), + dtype='float32', + default_initializer=paddle.nn.initializer.XavierUniform()) + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x(Tensor): Input tensor (batch, head, time1, time2). + + Returns: + Tensor:Output tensor. 
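+
+        Example (shape-level sketch only; values omitted):
+            >>> attn = LegacyRelPositionMultiHeadedAttention(
+            ...     n_head=4, n_feat=256, dropout_rate=0.0)
+            >>> x = paddle.randn((2, 4, 10, 10))  # (batch, head, time1, time2)
+            >>> tuple(attn.rel_shift(x).shape)    # shape is preserved
+            (2, 4, 10, 10)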
+ """ + b, h, t1, t2 = paddle.shape(x) + zero_pad = paddle.zeros((b, h, t1, 1)) + x_padded = paddle.concat([zero_pad, x], axis=-1) + x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1]) + # only keep the positions from 0 to time2 + x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2]) + + if self.zero_triu: + ones = paddle.ones((t1, t2)) + x = x * paddle.tril(ones, t2 - 1)[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query(Tensor): Query tensor (#batch, time1, size). + key(Tensor): Key tensor (#batch, time2, size). + value(Tensor): Value tensor (#batch, time2, size). + pos_emb(Tensor): Positional embedding tensor (#batch, time1, size). + mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + # (batch, time1, head, d_k) + q = paddle.transpose(q, [0, 2, 1, 3]) + + n_batch_pos = paddle.shape(pos_emb)[0] + p = paddle.reshape( + self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k]) + # (batch, head, time1, d_k) + p = paddle.transpose(p, [0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3]) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = paddle.matmul(q_with_bias_u, + paddle.transpose(k, [0, 1, 3, 2])) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = paddle.matmul(q_with_bias_v, + paddle.transpose(p, [0, 1, 3, 2])) + matrix_bd = self.rel_shift(matrix_bd) + # (batch, head, time1, time2) + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + return self.forward_attention(v, scores, mask) diff --git a/paddlespeech/t2s/modules/transformer/embedding.py b/paddlespeech/t2s/modules/transformer/embedding.py index d9339d20b..9524f07ee 100644 --- a/paddlespeech/t2s/modules/transformer/embedding.py +++ b/paddlespeech/t2s/modules/transformer/embedding.py @@ -185,3 +185,61 @@ class RelPositionalEncoding(nn.Layer): pe_size = paddle.shape(self.pe) pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ] return self.dropout(x), self.dropout(pos_emb) + + +class LegacyRelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000): + """ + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int, optional): [Maximum input length.]. Defaults to 5000. 
+ """ + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]: + return + pe = paddle.zeros((paddle.shape(x)[1], self.d_model)) + if self.reverse: + position = paddle.arange( + paddle.shape(x)[1] - 1, -1, -1.0, + dtype=paddle.float32).unsqueeze(1) + else: + position = paddle.arange( + 0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe + + def forward(self, x: paddle.Tensor): + """Compute positional encoding. + Args: + x (paddle.Tensor): Input tensor (batch, time, `*`). + Returns: + paddle.Tensor: Encoded tensor (batch, time, `*`). + paddle.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, :paddle.shape(x)[1]] + return self.dropout(x), self.dropout(pos_emb)
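
As a quick sanity check of the new legacy relative positional encoding, the following sketch (batch size and lengths chosen arbitrarily) shows the pair of tensors its `forward` returns, matching the `(x, pos_emb)` convention the MLM encoder expects:

```python
import paddle

from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding

pos_enc = LegacyRelPositionalEncoding(d_model=256, dropout_rate=0.1)
x = paddle.randn((2, 100, 256))  # (batch, time, d_model)

x_scaled, pos_emb = pos_enc(x)
print(x_scaled.shape)  # [2, 100, 256] -> input scaled by sqrt(d_model), then dropout
print(pos_emb.shape)   # [1, 100, 256] -> positional embedding in reversed (legacy) order
```

The reversed ordering (via `reverse=True`) is what distinguishes this module from `RelPositionalEncoding`, presumably so that weights from the ESPnet-style conformer checkpoints this code loads remain compatible.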