Merge pull request #3407 from zh794390558/roformer

Roformer
Hui Zhang 12 months ago committed by GitHub
commit 897dcc37e6

@ -1,14 +1,31 @@
# Aishell
## RoFormer Streaming
paddle version: 2.5.0
paddlespeech version: 1.5.0
Tesla V100-SXM2-32GB: 1 node, 4 cards
Global BatchSize: 32 * 4
Training Done: 1 day, 12:56:39.639646
### `decoding.decoding_chunk_size=16`
> chunk_size=16, ((16 - 1) * 4 + 7) * 10ms = (16 * 4 + 3) * 10ms = 670ms
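A minimal sketch of that latency arithmetic (assuming the usual conv2d subsampling rate of 4 with 7 frames of right context and a 10 ms frame shift; the helper name is ours, not part of PaddleSpeech):

```python
def chunk_latency_ms(chunk_size: int,
                     subsampling_rate: int = 4,
                     right_context: int = 7,
                     frame_shift_ms: int = 10) -> int:
    """Milliseconds of input audio consumed to emit one encoder chunk."""
    n_frames = (chunk_size - 1) * subsampling_rate + right_context
    return n_frames * frame_shift_ms

print(chunk_latency_ms(16))  # 670
```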
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | 16, -1 | - | 5.63 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 5.44 |
### `decoding.decoding_chunk_size=-1`
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | -1, -1 | - | 5.39 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | -1, -1 | - | 4.99 |
## Conformer Streaming
@ -24,6 +41,17 @@ Need set `decoding.decoding_chunk_size=16` when decoding.
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
## Conformer
paddle version: 2.2.2
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    r_num_blocks: 0     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.0 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
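A quick way to sanity-check the config once it lands (assuming this hunk is the `conf/chunk_roformer.yaml` referenced in the results table above; plain PyYAML is enough for a read-only look):

```python
import yaml

with open("conf/chunk_roformer.yaml") as f:
    cfg = yaml.safe_load(f)

# the RoPE switch lives under encoder_conf
print(cfg["encoder"])                             # conformer
print(cfg["encoder_conf"]["pos_enc_layer_type"])  # rope_pos
print(cfg["model_conf"]["ctc_weight"])            # 0.3
```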

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
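The bitransformer variant adds a right-to-left decoder (`r_num_blocks: 3`) whose loss is blended in through `reverse_weight` before mixing with CTC. A sketch of that combination, following the WeNet-style U2 loss these configs mirror (the per-batch loss values are made up for illustration):

```python
ctc_weight, reverse_weight = 0.3, 0.3  # values from model_conf above

def hybrid_loss(loss_ctc, loss_att_l2r, loss_att_r2l):
    # blend the forward (L2R) and reverse (R2L) attention-decoder losses first
    loss_att = (1 - reverse_weight) * loss_att_l2r + reverse_weight * loss_att_r2l
    # then interpolate with the CTC branch
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att

print(hybrid_loss(12.0, 8.0, 9.0))  # 0.3*12 + 0.7*(0.7*8 + 0.3*9) = 9.41
```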

@ -20,30 +20,6 @@ import numpy as np
import paddle
def define_argparse():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
parser.add_argument(
'--val_best', action="store_true", help='averaged model')
parser.add_argument(
'--num', default=5, type=int, help='nums for averaged model')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
return args
def average_checkpoints(dst_model="",
ckpt_dir="",
val_best=True,
@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
print(path_list)
avg = None
num = num
assert num == len(path_list)
for path in path_list:
print(f'Processing {path}')
@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
if avg[k] is not None:
avg[k] /= num
paddle.save(avg, dst_model)
print(f'Saving to {dst_model}')
meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if val_best else 'latest',
"avg_ckpt": dst_model,
"val_loss_mean": avg_val_score,
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
f.write(data + "\n")
def define_argparse():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
parser.add_argument(
'--val_best', action="store_true", help='averaged model')
parser.add_argument(
'--num', default=5, type=int, help='nums for averaged model')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
print(args)
return args
def main():
args = define_argparse()
average_checkpoints(
dst_model=args.dst_model,
ckpt_dir=args.ckpt_dir,
val_best=args.val_best,
num=args.num,
min_epoch=args.min_epoch,
max_epoch=args.max_epoch)
if __name__ == '__main__':

@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
text_lengths)
ctc_time = time.time() - start
#logger.debug(f"ctc time: {ctc_time}")
if loss_ctc is None:
loss = loss_att
elif loss_att is None:
@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
if decoder_type == 'transformer':
configs['model_conf'].pop('reverse_weight', None)
configs['decoder_conf'].pop('r_num_blocks', None)
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])

@ -15,6 +15,7 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition."""
import math
from typing import List
from typing import Tuple
import paddle
@ -26,7 +27,10 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"MultiHeadedAttention", "RelPositionMultiHeadedAttention",
"RoPERelPositionMultiHeadedAttention"
]
# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
@ -165,6 +169,7 @@ class MultiHeadedAttention(nn.Layer):
and `head * d_k == size`
"""
# (B,T,D) -> (B,T,H,D/H)
q, k, v = self.forward_qkv(query, key, value)
# when export onnx model, for 1st chunk, we feed
@ -373,3 +378,139 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with RoPE relative position encoding."""
def __init__(self,
n_head,
n_feat,
dropout_rate,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
super().__init__(n_head, n_feat, dropout_rate)
def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
"""重新对齐tensor批量版expand_dims
axes原来的第i维对齐新tensor的第axes[i]
ndim新tensor的维度
"""
assert len(axes) == tensor.dim()
assert ndim or min(axes) >= 0
ndim = ndim or max(axes) + 1
# a[0, None, 1] = a[0, np.newaxis, 1]
indices = [None] * ndim
for i in axes:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices[i] = slice(None)
return tensor[indices]
def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""应用RoPE到tensors中
其中sinusoidal.shape=[B, T, D]tensors为tensor的列表
tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'
# (B,H,T,D)
ndim = tensors[0].dim()
_, H, T, D = tensors[0].shape
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
# sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
# Eq. (34): out = x * cos_pos + x2 * sin_pos
outputs.append(tensor * cos_pos + tensor2 * sin_pos)
return outputs[0] if len(outputs) == 1 else outputs
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
# q_t always is chunk_size
q_t = q.shape[2]
q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
# k will increase when in streaming decoding.
k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
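Worth noting what the rotation buys: attention scores between RoPE-rotated queries and keys depend only on the relative offset between their positions. A tiny NumPy check of that property, independent of the Paddle implementation above (the `rope` helper is ours and mirrors the even/odd pairing used in `apply_rotary_position_embeddings`):

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate the even/odd pairs of vector x to absolute position `pos`."""
    d = x.shape[-1]
    theta = pos * base ** (-np.arange(0, d, 2) / d)  # (d/2,) rotation angles
    cos, sin = np.cos(theta), np.sin(theta)
    out = np.empty_like(x)
    out[0::2] = x[0::2] * cos - x[1::2] * sin
    out[1::2] = x[0::2] * sin + x[1::2] * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)
# same relative offset (2) at different absolute positions -> same attention score
print(np.isclose(rope(q, 5) @ rope(k, 3), rope(q, 105) @ rope(k, 103)))  # True
```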

@ -85,18 +85,21 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
reverse (bool, optional): Not used. Defaults to False.
"""
nn.Layer.__init__(self)
self.d_model = paddle.to_tensor(d_model)
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate)
self.base = paddle.to_tensor(10000.0)
self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]
position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
# base^{-2(i-1)/d}, i \in (1,2,...,d/2)
div_term = paddle.exp(
-paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
(paddle.log(self.base) / self.d_model))
# [B,T,D]
self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, :, 1::2] = paddle.cos(position * div_term)
@ -161,6 +164,98 @@ class RelPositionalEncoding(PositionalEncoding):
assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
x = x * self.xscale
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)
# RotaryRelPositionalEncoding is same to RelPositionalEncoding
class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
"""Scaled Rotary Relative positional encoding module.
Position interpolation: https://arxiv.org/pdf/2306.15595v2.pdf
"""
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
scale=1):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
scale (int): Interpolation max input length to `scale * max_len` positions.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
self.pscale = paddle.to_tensor(scale)
self.max_len = max_len * scale
def sinusoidal_embeddings(self,
pos: paddle.Tensor,
dim: paddle.Tensor,
base=10000) -> paddle.Tensor:
"""计算pos位置的dim维sinusoidal编码"""
assert dim % 2 == 0
# (d/2,)
indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
# pos (1, T), indices (d/2,) -> (1, T, d/2)
embeddings = paddle.einsum('...,d->...d', pos, indices)
# (1, T, d/2, 2)
embeddings = paddle.stack(
[paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
# (1, T, d)
embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
return embeddings
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
x = x * self.xscale
B, T, D = x.shape
assert D == self.d_model
# position interpolation
start = 0
end = T * self.pscale
assert end <= self.max_len
position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
position *= 1.0 / self.pscale
pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
pos_emb = pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a
non-streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int): start offset
size (int): required size of position encoding
Returns:
paddle.Tensor: Corresponding position encoding, #[1, T, D].
"""
# position interpolation
start = offset
end = (offset + size) * self.pscale
assert end <= self.max_len
position = paddle.arange(
start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
position *= 1.0 / self.pscale
pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
return self.dropout(pe)
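The position-interpolation idea above boils down to generating `scale` times as many positions and compressing them back into the trained range, so a model trained up to `max_len` positions can address longer inputs on a finer grid. A one-line sketch of the positions that `forward` produces (values only, no embedding):

```python
import numpy as np

T, scale = 8, 2  # chunk length and interpolation scale
positions = np.arange(0, T * scale, dtype=np.float32) / scale
print(positions[:6])  # [0.  0.5 1.  1.5 2.  2.5]
```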

@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "rope_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos": elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding pos_enc_class = NoPositionalEncoding
else: else:
@ -230,14 +233,14 @@ class BaseEncoder(nn.Layer):
xs = self.global_cmvn(xs)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
xs, _, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
elayers, _, cache_t1, _ = att_cache.shape
chunk_size = xs.shape[1]
attention_key_size = cache_t1 + chunk_size
# only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
@ -474,21 +477,35 @@ class ConformerEncoder(BaseEncoder):
activation = get_activation(activation_type)
# self-attention module definition
encoder_dim = output_size
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
activation)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.encoders = nn.LayerList([
ConformerEncoderLayer(
size=encoder_dim,
self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
feed_forward=positionwise_layer(*positionwise_layer_args),
feed_forward_macaron=positionwise_layer(
@ -580,15 +597,23 @@ class SqueezeformerEncoder(nn.Layer):
activation = get_activation(activation_type)
# self-attention module definition
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward

@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (paddle.nn.Layer): Self-attention module instance.
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
