diff --git a/README.md b/README.md
index a6db486d..49e40624 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- Scan the QR code below with WeChat to join the official technical exchange group, get the bonus (more than 20 GB of learning materials, such as papers, code, and videos), and receive the live links for the lessons. We look forward to your participation.
-
+
## Installation
diff --git a/README_cn.md b/README_cn.md
index a71e5b24..bf3ff4df 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -162,22 +162,7 @@
- 🧩 Cascaded model applications: as an extension of traditional speech tasks, we combine natural language processing, computer vision, and other tasks to deliver industrial-grade applications closer to real-world needs.
-### Recent Events
-
- ❗️Big news❗️ PaddlePaddle Smart Finance industry live course series
-✅ Covering four mainstream finance scenarios: intelligent risk control, intelligent operations and maintenance, intelligent marketing, and intelligent customer service
-
-📆 September 6 - September 29, every Tuesday and Thursday at 19:00
-+ In-depth insights into the smart finance industry
-+ 8 high-quality live lessons combining theory and practice
-+ 10+ real industrial scenarios taught through examples and hands-on practice
-+ Plus free compute resources, completion certificates, and other gifts waiting for you
-Scan the code to sign up and save the live-stream link, and connect in depth with industry experts
-
-
-
-
-
+
### Recent Updates
- 👑 2022.10.11: Added [Wav2vec2ASR](./examples/librispeech/asr3), fine-tuning wav2vec2.0 on LibriSpeech for the ASR task.
- 🔥 2022.09.26: Added Voice Cloning, TTS finetuning, and ERNIE-SAT to the [PaddleSpeech web demo](./demos/speech_web).
@@ -200,13 +185,13 @@
### 🔥 Join the Technical Exchange Group for Member Benefits
- - 3-day live course links: an in-depth look at the key technologies behind the three core speech systems PP-TTS, PP-ASR, and PP-VPR
+ - 3-day live course links: an in-depth look at speech interaction technologies: [one-sentence speech synthesis], [few-shot speech synthesis], and [customized speech recognition]
- 20 GB learning gift pack: video courses, cutting-edge papers, and study materials
Scan the QR code with WeChat to follow the official account, then click "Sign up now" and fill in the questionnaire to join the official exchange group, get faster answers to your questions, and exchange ideas with developers from many industries. We look forward to your joining.
-
+
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 4e76da03..2f3c9d09 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -9,7 +9,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [Ds2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
-[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.4.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
+[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 437f6463..00414336 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -52,7 +52,7 @@ class ASRExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
- default='conformer_u2pp_wenetspeech',
+ default='conformer_u2pp_online_wenetspeech',
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
@@ -470,7 +470,7 @@ class ASRExecutor(BaseExecutor):
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
- model: str='conformer_u2pp_wenetspeech',
+ model: str='conformer_u2pp_online_wenetspeech',
lang: str='zh',
sample_rate: int=16000,
config: os.PathLike=None,
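With the patch above, the CLI's default ASR model becomes `conformer_u2pp_online_wenetspeech`. As a quick sanity check, here is a minimal sketch of calling the executor directly (the API shown matches the documented PaddleSpeech CLI usage; `zh.wav` is a placeholder for any 16 kHz Mandarin recording):

    from paddlespeech.cli.asr.infer import ASRExecutor

    asr = ASRExecutor()
    # No model argument: falls back to the new default,
    # conformer_u2pp_online_wenetspeech.
    text = asr(audio_file="zh.wav", lang="zh", sample_rate=16000)
    print(text)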
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index f5ec655b..8e9ecc4b 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -25,7 +25,6 @@ model_alias = {
"deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"conformer": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
- "conformer_u2pp": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_u2pp_online": ["paddlespeech.s2t.models.u2:U2Model"],
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index efd6bb3f..df50a6a9 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -68,32 +68,12 @@ asr_dynamic_pretrained_models = {
'',
},
},
- "conformer_u2pp_wenetspeech-zh-16k": {
- '1.1': {
- 'url':
- 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.3.model.tar.gz',
- 'md5':
- '662b347e1d2131b7a4dc5398365e2134',
- 'cfg_path':
- 'model.yaml',
- 'ckpt_path':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10',
- 'model':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
- 'params':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
- 'lm_url':
- '',
- 'lm_md5':
- '',
- },
- },
"conformer_u2pp_online_wenetspeech-zh-16k": {
- '1.1': {
+ '1.3': {
'url':
- 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.4.model.tar.gz',
+ 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz',
'md5':
- '3100fc1eac5779486cab859366992d0b',
+ '62d230c1bf27731192aa9d3b8deca300',
'cfg_path':
'model.yaml',
'ckpt_path':
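Each registry entry pairs a download URL with an md5 checksum, so a consumer can verify the archive before unpacking. A hypothetical sketch of that flow (`fetch_model` is not a PaddleSpeech API; it only illustrates how the 'url' and 'md5' fields fit together):

    import hashlib
    import urllib.request

    def fetch_model(url: str, md5: str, dest: str) -> str:
        """Download a model archive and verify its md5 checksum."""
        urllib.request.urlretrieve(url, dest)
        with open(dest, "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()
        if digest != md5:
            raise ValueError(f"md5 mismatch: expected {md5}, got {digest}")
        return dest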
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 128f87c0..d9568dcc 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -19,7 +19,6 @@ from typing import Tuple
import paddle
from paddle import nn
-from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddlespeech.s2t.modules.align import Linear
@@ -56,16 +55,6 @@ class MultiHeadedAttention(nn.Layer):
self.linear_out = Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
- def _build_once(self, *args, **kwargs):
- super()._build_once(*args, **kwargs)
- # if self.self_att:
- # self.linear_kv = Linear(self.n_feat, self.n_feat*2)
- if not self.training:
- self.weight = paddle.concat(
- [self.linear_k.weight, self.linear_v.weight], axis=-1)
- self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
- self._built = True
-
def forward_qkv(self,
query: paddle.Tensor,
key: paddle.Tensor,
@@ -87,13 +76,8 @@ class MultiHeadedAttention(nn.Layer):
n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
- if self.training:
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
- else:
- k, v = F.linear(key, self.weight, self.bias).view(
- n_batch, -1, 2 * self.h, self.d_k).split(
- 2, axis=2)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
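The deleted `_build_once` fast path cached a fused K/V weight at eval time, relying on the fact that concatenating the two projection matrices along the output axis is equivalent to applying `linear_k` and `linear_v` separately; the patch drops the cache and always takes the separate-projection path. A minimal sketch of that equivalence (plain `paddle.nn.Linear` stands in for the repo's aligned `Linear` wrapper; shapes are arbitrary):

    import paddle
    import paddle.nn.functional as F
    from paddle import nn

    n_feat = 8
    x = paddle.randn([2, 5, n_feat])  # (batch, time, feature)
    linear_k = nn.Linear(n_feat, n_feat)
    linear_v = nn.Linear(n_feat, n_feat)

    # Fused projection, as the removed eval-time path built it.
    weight = paddle.concat([linear_k.weight, linear_v.weight], axis=-1)
    bias = paddle.concat([linear_k.bias, linear_v.bias])
    k_fused, v_fused = paddle.split(F.linear(x, weight, bias), 2, axis=-1)

    # Identical (up to float tolerance) to the separate projections
    # now used unconditionally in forward_qkv.
    assert paddle.allclose(k_fused, linear_k(x), atol=1e-6)
    assert paddle.allclose(v_fused, linear_v(x), atol=1e-6)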