Merge branch 'PaddlePaddle:develop' into develop

pull/2615/head
HuangLiangJie 3 years ago committed by GitHub
commit b2597bc0c3

@ -157,6 +157,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV). - 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update ### Recent Update
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech. - 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web). - 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web).
- ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder. - ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder.
@ -923,8 +924,8 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
## Acknowledgement ## Acknowledgement
- Many thanks to [HighCWu](https://github.com/HighCWu) for adding [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) examples. - Many thanks to [HighCWu](https://github.com/HighCWu) for adding [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) examples.
- Many thanks to [david-95](https://github.com/david-95) improved TTS, fixed multi-punctuation bug, and contributed to multiple program and data. - Many thanks to [david-95](https://github.com/david-95) for fixing the multi-punctuation bug, contributing multiple programs and data, and adding [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- Many thanks to [BarryKCL](https://github.com/BarryKCL) improved TTS Chinses frontend based on [G2PW](https://github.com/GitYCC/g2pW). - Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese Frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help. - Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function. - Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function.

@ -164,7 +164,8 @@
### Recent Updates ### Recent Updates
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), fine-tuning of wav2vec2.0 for the ASR task on LibriSpeech. - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) support to the TTS Chinese text frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), finetuning of wav2vec2.0 for the ASR task on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT to the [PaddleSpeech web demo](./demos/speech_web). - 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT to the [PaddleSpeech web demo](./demos/speech_web).
- ⚡ 2022.09.09: Add an AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) based on the ECAPA-TDNN speaker model. - ⚡ 2022.09.09: Add an AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) based on the ECAPA-TDNN speaker model.
- ⚡ 2022.08.25: Release the TTS [finetune](./examples/other/tts_finetune/tts3) example. - ⚡ 2022.08.25: Release the TTS [finetune](./examples/other/tts_finetune/tts3) example.
@ -928,7 +929,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
## Acknowledgements ## Acknowledgements
- Many thanks to [HighCWu](https://github.com/HighCWu) for adding the [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) code examples. - Many thanks to [HighCWu](https://github.com/HighCWu) for adding the [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) code examples.
- Many thanks to [david-95](https://github.com/david-95) for fixing the error caused by multiple punctuation marks at the end of a sentence and contributing multiple programs and data. - Many thanks to [david-95](https://github.com/david-95) for fixing the TTS error caused by multiple punctuation marks at the end of a sentence, contributing multiple programs and data, and adding [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) support to the TTS Chinese text frontend.
- Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese text frontend based on [G2PW](https://github.com/GitYCC/g2pW). - Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese text frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention and advice, and for help on many issues. - Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention and advice, and for help on many issues.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR on [short audio](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long audio](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) using PaddleSpeech. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR on [short audio](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long audio](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) using PaddleSpeech.

@ -53,3 +53,22 @@ Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.052110 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.052110 |
## U2PP Streaming Pretrained Model
Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | 16 | 0.057031 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | 16 | 0.068826 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | 16 | 0.069111 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | 16 | 0.059213 |
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | -1 | 0.049256 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.052086 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.052267 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.047198 |
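To put the tables in perspective, the relative CER change between two rows can be computed directly from the reported values. The following is a minimal sketch, not part of the commit: the helper name `relative_cer_reduction` is made up for illustration, and the inputs are the `attention_rescoring` CERs listed above.

```python
def relative_cer_reduction(baseline: float, candidate: float) -> float:
    """Fraction by which `candidate` lowers the CER relative to `baseline`.

    A negative result means the candidate is worse than the baseline.
    """
    return (baseline - candidate) / baseline


# attention_rescoring CERs on aishell1, copied from the tables above.
cer_32m_full = 0.052110      # 32.52 M conformer, chunk size -1
cer_u2pp_full = 0.047198     # 122.88 M U2PP conformer, chunk size -1
cer_u2pp_chunk16 = 0.059213  # 122.88 M U2PP conformer, chunk size 16

# Larger U2PP model vs. the 32.52 M model, both with full context.
print(f"{relative_cer_reduction(cer_32m_full, cer_u2pp_full):.1%}")
# Cost of streaming (chunk size 16) vs. full context for the U2PP model.
print(f"{relative_cer_reduction(cer_u2pp_full, cer_u2pp_chunk16):.1%}")
```

The second figure comes out negative, which reflects the higher CER reported for chunk size 16.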

@ -113,7 +113,7 @@ class ServerExecutor(BaseExecutor):
""" """
config = get_config(config_file) config = get_config(config_file)
if self.init(config): if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True) uvicorn.run(app, host=config.host, port=config.port)
@cli_server_register( @cli_server_register(

@ -18,5 +18,6 @@ from . import exps
from . import frontend from . import frontend
from . import models from . import models
from . import modules from . import modules
from . import ssml
from . import training from . import training
from . import utils from . import utils

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
import os import os
import re
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -33,6 +34,7 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import from paddlespeech.utils.dynamic_import import dynamic_import
# remove [W:onnxruntime: xxx] from ort # remove [W:onnxruntime: xxx] from ort
ort.set_default_logger_severity(3) ort.set_default_logger_severity(3)
@ -103,14 +105,15 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
sentences = [] sentences = []
with open(text_file, 'rt') as f: with open(text_file, 'rt') as f:
for line in f: for line in f:
items = line.strip().split() if line.strip() != "":
utt_id = items[0] items = re.split(r"\s+", line.strip(), 1)
if lang == 'zh': utt_id = items[0]
sentence = "".join(items[1:]) if lang == 'zh':
elif lang == 'en': sentence = "".join(items[1:])
sentence = " ".join(items[1:]) elif lang == 'en':
elif lang == 'mix': sentence = " ".join(items[1:])
sentence = " ".join(items[1:]) elif lang == 'mix':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence)) sentences.append((utt_id, sentence))
return sentences return sentences
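The change to `get_sentences` above splits each line on the first run of whitespace only, so whitespace inside the sentence survives. That matters for SSML input, where tag attributes contain spaces. A minimal sketch of the difference follows; the helper `split_utt_line` and the `say-as` tag name are assumptions for illustration, not code from the commit.

```python
import re


def split_utt_line(line: str):
    """Split 'utt_id text...' on the first whitespace run only."""
    items = re.split(r"\s+", line.strip(), 1)
    utt_id = items[0]
    # A line holding only an utterance id yields an empty sentence.
    sentence = items[1] if len(items) > 1 else ""
    return utt_id, sentence


line = '001 <speak>a<say-as pinyin="b1 c2">d</say-as></speak>\n'

# New behaviour: the markup after the id is kept intact.
print(split_utt_line(line))

# Old behaviour for lang == 'zh': split on every space, then join with "",
# which glues the tag name to its attribute and breaks the XML.
old_items = line.strip().split()
print("".join(old_items[1:]))
```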
@ -180,11 +183,20 @@ def run_frontend(frontend: object,
to_tensor: bool=True): to_tensor: bool=True):
outs = dict() outs = dict()
if lang == 'zh': if lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = {}
text, if text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text,
merge_sentences=merge_sentences, re.DOTALL):
get_tone_ids=get_tone_ids, input_ids = frontend.get_input_ids_ssml(
to_tensor=to_tensor) text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
else:
input_ids = frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
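The dispatch added to `run_frontend` decides between the SSML path and the plain-text path with a single regular expression. A small stand-alone sketch of that check is shown below; the function name `looks_like_ssml` is hypothetical, while the pattern and the `re.DOTALL` flag are taken from the hunk above.

```python
import re

# Same pattern as in the hunk above; DOTALL lets "." match newlines,
# so a <speak> block spanning several lines is still detected.
SSML_PATTERN = re.compile(r".*?<speak>.*?</speak>.*", re.DOTALL)


def looks_like_ssml(text: str) -> bool:
    """Return True when the text contains a <speak>...</speak> block."""
    return text.strip() != "" and SSML_PATTERN.match(text) is not None


print(looks_like_ssml("<speak>a</speak>"))
print(looks_like_ssml("plain text without markup"))
print(looks_like_ssml("before\n<speak>a\nb</speak>\nafter"))
```

Note that the check only looks for the surrounding tags; whether the markup is well-formed XML is decided later, when the block is parsed.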

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import re import re
from operator import itemgetter
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -31,6 +32,7 @@ from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor
INITIALS = [ INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
@ -81,6 +83,7 @@ class Frontend():
g2p_model="g2pW", g2p_model="g2pW",
phone_vocab_path=None, phone_vocab_path=None,
tone_vocab_path=None): tone_vocab_path=None):
self.mix_ssml_processor = MixTextProcessor()
self.tone_modifier = ToneSandhi() self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer() self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!" self.punc = ":,;。?!“”‘’':,;.?!"
@ -281,6 +284,65 @@ class Frontend():
phones_list.append(merge_list) phones_list.append(merge_list)
return phones_list return phones_list
def _split_word_to_char(self, words):
res = []
for x in words:
res.append(x)
return res
# If the SSML input specifies pinyin, assign that pinyin to the words
def _g2p_assign(self,
words: List[str],
pinyin_spec: List[str],
merge_sentences: bool=True) -> List[List[str]]:
phones_list = []
initials = []
finals = []
words = self._split_word_to_char(words[0])
for pinyin, char in zip(pinyin_spec, words):
sub_initials = []
sub_finals = []
pinyin = pinyin.replace("u:", "v")
# self.pinyin2phone: dict mapping each pinyin to its initial (sheng_mu) and final (yun_mu)
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
sub_initials.append(initial_final_list[0])
sub_finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
sub_initials.append('')
sub_finals.append(initial_final_list[0])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
phones = []
for c, v in zip(initials, finals):
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c and c not in self.punc:
phones.append(c)
if c and c in self.punc:
phones.append('sp')
if v and v not in self.punc:
phones.append(v)
phones_list.append(phones)
if merge_sentences:
merge_list = sum(phones_list, [])
# rm the last 'sp' to avoid the noise at the end
# cause in the training data, no 'sp' in the end
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)
return phones_list
def _merge_erhua(self, def _merge_erhua(self,
initials: List[str], initials: List[str],
finals: List[str], finals: List[str],
@ -396,6 +458,52 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
#@an added for ssml pinyin
def get_phonemes_ssml(self,
ssml_inputs: list,
merge_sentences: bool=True,
with_erhua: bool=True,
robot: bool=False,
print_info: bool=False) -> List[List[str]]:
all_phonemes = []
for word_pinyin_item in ssml_inputs:
phonemes = []
sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item)
sentences = self.text_normalizer.normalize(sentence)
if len(pinyin_spec) == 0:
phonemes = self._g2p(
sentences,
merge_sentences=merge_sentences,
with_erhua=with_erhua)
else:
# phonemes should be pinyin_spec
phonemes = self._g2p_assign(
sentences, pinyin_spec, merge_sentences=merge_sentences)
all_phonemes = all_phonemes + phonemes
if robot:
new_phonemes = []
for sentence in all_phonemes:
new_sentence = []
for item in sentence:
# `er` only has tone `2`
if item[-1] in "12345" and item != "er2":
item = item[:-1] + "1"
new_sentence.append(item)
new_phonemes.append(new_sentence)
all_phonemes = new_phonemes
if print_info:
print("----------------------------")
print("text norm results:")
print(sentences)
print("----------------------------")
print("g2p results:")
print(all_phonemes[0])
print("----------------------------")
return [sum(all_phonemes, [])]
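The `robot` branch of `get_phonemes_ssml` flattens every tone to tone 1, leaving `er2` untouched. The rule can be tried in isolation with the sketch below; `flatten_tones` is a hypothetical helper name, and the emptiness guard is an addition that the original loop does not have.

```python
from typing import List


def flatten_tones(phonemes: List[str]) -> List[str]:
    """Replace a trailing tone digit with "1", except for "er2"."""
    flat = []
    for item in phonemes:
        # Guard against empty strings before indexing the last character.
        if item and item[-1] in "12345" and item != "er2":
            item = item[:-1] + "1"
        flat.append(item)
    return flat


print(flatten_tones(["n", "i3", "h", "ao3", "er2", "sp"]))
```

Initials and the `sp` pause marker carry no tone digit, so they pass through unchanged.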
def get_input_ids(self, def get_input_ids(self,
sentence: str, sentence: str,
merge_sentences: bool=True, merge_sentences: bool=True,
@ -405,6 +513,7 @@ class Frontend():
add_blank: bool=False, add_blank: bool=False,
blank_token: str="<pad>", blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
@ -437,3 +546,49 @@ class Frontend():
if temp_phone_ids: if temp_phone_ids:
result["phone_ids"] = temp_phone_ids result["phone_ids"] = temp_phone_ids
return result return result
# @an added for ssml
def get_input_ids_ssml(
self,
sentence: str,
merge_sentences: bool=True,
get_tone_ids: bool=False,
robot: bool=False,
print_info: bool=False,
add_blank: bool=False,
blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
l_inputs = MixTextProcessor.get_pinyin_split(sentence)
phonemes = self.get_phonemes_ssml(
l_inputs,
merge_sentences=merge_sentences,
print_info=print_info,
robot=robot)
result = {}
phones = []
tones = []
temp_phone_ids = []
temp_tone_ids = []
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if add_blank:
phones = insert_after_character(phones, blank_token)
if tones:
tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
# if paddle.to_tensor() is used with onnxruntime, the first call will be too slow
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
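The pinyin-assignment step behind `get_input_ids_ssml` (`_g2p_assign`) maps each user-supplied pinyin to an initial and a final, and turns punctuation into the `sp` pause marker. The sketch below is a simplified stand-alone version under stated assumptions: the function name `assign_phones`, the toy `TOY_PINYIN2PHONE` table, and the punctuation set are made up for illustration, and unknown tokens are emitted once rather than as an initial/final pair.

```python
from typing import Dict, List

# Toy lookup table (an assumption): pinyin -> "initial final" or just "final".
TOY_PINYIN2PHONE: Dict[str, str] = {
    "hao3": "h ao3",
    "nv3": "n v3",
    "a1": "a1",  # zero-initial syllable: only a final
}

PUNC = set("：，；。？！:,;.?!")


def assign_phones(pinyin_spec: List[str],
                  pinyin2phone: Dict[str, str]) -> List[str]:
    """Turn user-specified pinyin into a flat phone list."""
    phones: List[str] = []
    for pinyin in pinyin_spec:
        pinyin = pinyin.replace("u:", "v")
        if pinyin in pinyin2phone:
            # One element means a lone final, two mean initial + final.
            phones.extend(p for p in pinyin2phone[pinyin].split(" ") if p)
        elif pinyin in PUNC:
            phones.append("sp")
        elif pinyin:
            phones.append(pinyin)
    # Drop a trailing pause, mirroring the merge_sentences branch above.
    if phones and phones[-1] == "sp":
        phones = phones[:-1]
    return phones


print(assign_phones(["hao3", "a1", "。"], TOY_PINYIN2PHONE))
print(assign_phones(["nu:3"], TOY_PINYIN2PHONE))
```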

@ -0,0 +1,14 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .xml_processor import *

@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
import re
import xml.dom.minidom
import xml.parsers.expat
from xml.dom.minidom import Node
from xml.dom.minidom import parseString
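'''
Note: XML has 5 special characters: &<>"'
Option 1: wrap the string that contains special characters in the special <![CDATA[ ]]> tag.
For example:
<TitleName><![CDATA["姓名"]]></TitleName>
Option 2: represent these characters with XML escape sequences. The escape sequences for the 5 special characters are:
& &amp;
< &lt;
> &gt;
" &quot;
' &apos;
For example:
<TitleName>&quot;姓名&quot;</TitleName>
'''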
class MixTextProcessor():
def __repr__(self):
return "@an MixTextProcessor class"
def get_xml_content(self, mixstr):
'''Return the XML content of the string.'''
xmlptn = re.compile(r"<speak>.*?</speak>", re.M | re.S)
ctn = re.search(xmlptn, mixstr)
if ctn:
return ctn.group(0)
else:
return None
def get_content_split(self, mixstr):
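''' Split the text into non-XML and XML parts, appended to the list in order; the strings keep their punctuation.
Whitespace must not be stripped, because XML tag attributes contain spaces.
'''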
ctlist = []
# print("Testing:",mixstr[:20])
patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
mat = re.match(patn, mixstr)
if mat:
pre_xml = mat.group(1)
in_xml = mat.group(2)
after_xml = mat.group(3)
ctlist.append(pre_xml)
ctlist.append(in_xml)
ctlist.append(after_xml)
return ctlist
else:
ctlist.append(mixstr)
return ctlist
@classmethod
def get_pinyin_split(cls, mixstr):
ctlist = []
patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
mat = re.match(patn, mixstr)
if mat:
pre_xml = mat.group(1)
in_xml = mat.group(2)
after_xml = mat.group(3)
ctlist.append([pre_xml, []])
dom = DomXml(in_xml)
pinyinlist = dom.get_pinyins_for_xml()
ctlist = ctlist + pinyinlist
ctlist.append([after_xml, []])
else:
ctlist.append([mixstr, []])
return ctlist
class DomXml():
def __init__(self, xmlstr):
self.tdom = parseString(xmlstr) #Document
self.root = self.tdom.documentElement #Element
self.rnode = self.tdom.childNodes #NodeList
def get_text(self):
'''Return a list of all text content in the XML.'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
res.append(x1.value)
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
res.append(x2.data)
else:
for x3 in x2.childNodes:
if isinstance(x3, xml.dom.minidom.Text):
res.append(x3.data)
else:
print("len(nodes of x3):", len(x3.childNodes))
return res
def get_xmlchild_list(self):
'''Return the XML content as a list, including all text content (without tags).'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
res.append(x1.value)
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
res.append(x2.data)
else:
for x3 in x2.childNodes:
if isinstance(x3, xml.dom.minidom.Text):
res.append(x3.data)
else:
print("len(nodes of x3):", len(x3.childNodes))
print(res)
return res
def get_pinyins_for_xml(self):
'''Return the XML content as a list of [string, pinyin list] pairs.'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
t = re.sub(r"\s+", "", x1.value)
res.append([t, []])
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
t = re.sub(r"\s+", "", x2.data)
res.append([t, []])
else:
# print("x2",x2,x2.tagName)
pinyins = []  # reset so a tag without a pinyin attribute does not reuse a stale or undefined list
if x2.hasAttribute('pinyin'):
pinyin_value = x2.getAttribute("pinyin")
pinyins = pinyin_value.split(" ")
for x3 in x2.childNodes:
# print('x3',x3)
if isinstance(x3, xml.dom.minidom.Text):
t = re.sub(r"\s+", "", x3.data)
res.append([t, pinyins])
else:
print("len(nodes of x3):", len(x3.childNodes))
return res
def get_all_tags(self, tag_name):
'''Get all tags and their attribute values.'''
alltags = self.root.getElementsByTagName(tag_name)
for x in alltags:
if x.hasAttribute('pinyin'): # pinyin
print(x.tagName, 'pinyin',
x.getAttribute('pinyin'), x.firstChild.data)
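What `DomXml.get_pinyins_for_xml` produces is easier to see on a small input. The sketch below is a simplified re-implementation using only `xml.dom.minidom` from the standard library. The function name `pinyin_pairs` and the `say-as` tag name are assumptions for illustration; the code in the commit only checks for a `pinyin` attribute and does not look at the tag name.

```python
import re
from xml.dom.minidom import Node, parseString


def pinyin_pairs(ssml: str):
    """Return [text, pinyin_list] pairs for the children of the root tag."""
    root = parseString(ssml).documentElement
    pairs = []
    for child in root.childNodes:
        if child.nodeType == Node.TEXT_NODE:
            # Plain text: no pinyin specified, whitespace removed.
            pairs.append([re.sub(r"\s+", "", child.data), []])
        elif child.nodeType == Node.ELEMENT_NODE:
            pinyins = []
            if child.hasAttribute("pinyin"):
                pinyins = child.getAttribute("pinyin").split()
            text = "".join(n.data for n in child.childNodes
                           if n.nodeType == Node.TEXT_NODE)
            pairs.append([re.sub(r"\s+", "", text), pinyins])
    return pairs


ssml = '<speak>你好<say-as pinyin="zhong4">重</say-as>要</speak>'
for text, pinyins in pinyin_pairs(ssml):
    print(repr(text), pinyins)
```

Each pair has the same `[text, pinyins]` shape that `get_phonemes_ssml` unpacks with `itemgetter(0, 1)`: an empty pinyin list sends the text through the regular G2P path, a non-empty one through `_g2p_assign`.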

@ -25,8 +25,6 @@ DefinedClassifier = {
'ErnieLinear': ErnieLinear, 'ErnieLinear': ErnieLinear,
} }
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
def _clean_text(text, punc_list): def _clean_text(text, punc_list):
text = text.lower() text = text.lower()
@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
return text return text
def preprocess(text, punc_list): def preprocess(text, punc_list, tokenizer):
clean_text = _clean_text(text, punc_list) clean_text = _clean_text(text, punc_list)
assert len(clean_text) > 0, f'Invalid input string: {text}' assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = tokenizer( tokenized_input = tokenizer(
@ -51,7 +49,8 @@ def test(args):
with open(args.config) as f: with open(args.config) as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
print("========Args========") print("========Args========")
print(yaml.safe_dump(vars(args))) print(yaml.safe_dump(vars(args), allow_unicode=True))
# print(args)
print("========Config========") print("========Config========")
print(config) print(config)
@ -61,10 +60,16 @@ def test(args):
punc_list.append(line.strip()) punc_list.append(line.strip())
model = DefinedClassifier[config["model_type"]](**config["model"]) model = DefinedClassifier[config["model_type"]](**config["model"])
# print(model)
pretrained_token = config['data_params']['pretrained_token']
tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
# tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
state_dict = paddle.load(args.checkpoint) state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"]) model.set_state_dict(state_dict["main_params"])
model.eval() model.eval()
_inputs = preprocess(args.text, punc_list) _inputs = preprocess(args.text, punc_list, tokenizer)
seq_len = _inputs['seq_len'] seq_len = _inputs['seq_len']
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0) input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0) seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
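The punctuation test script above stops creating the tokenizer at import time and passes it into `preprocess` instead, with the model name read from the config. The sketch below shows the same pattern without any Paddle dependency. `WhitespaceTokenizer` is a hypothetical stand-in, not `ErnieTokenizer`, and the body of `clean_text` is an assumption, since the diff shows only its first and last lines.

```python
class WhitespaceTokenizer:
    """Hypothetical stand-in for a pretrained tokenizer."""

    def __call__(self, text):
        tokens = text.split()
        return {"input_ids": list(range(len(tokens))), "seq_len": len(tokens)}


def clean_text(text, punc_list):
    # Assumed behaviour: lower-case, then strip the listed punctuation.
    text = text.lower()
    for punc in punc_list:
        text = text.replace(punc, "")
    return text


def preprocess(text, punc_list, tokenizer):
    """The tokenizer is supplied by the caller, not read from a global."""
    cleaned = clean_text(text, punc_list)
    assert len(cleaned) > 0, f"Invalid input string: {text}"
    return tokenizer(cleaned)


# The caller builds the tokenizer once (in the commit: from the config value)
# and hands it to preprocess.
tokenizer = WhitespaceTokenizer()
print(preprocess("Hello, World!", [",", "!"], tokenizer))
```

Passing the tokenizer in keeps importing the module cheap and lets different configs select different pretrained vocabularies.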
