add roformer

pull/3407/head
Hui Zhang 12 months ago
parent 94987f26df
commit 03e9ea9e52

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 3
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
__all__ = [
"MultiHeadedAttention", "RelPositionMultiHeadedAttention",
"RoPERelPositionMultiHeadedAttention"
]
# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
@ -165,6 +168,7 @@ class MultiHeadedAttention(nn.Layer):
and `head * d_k == size`
"""
# (B,T,D) -> (B,T,H,D/H)
q, k, v = self.forward_qkv(query, key, value)
# when export onnx model, for 1st chunk, we feed
@ -373,3 +377,131 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with RoPE relative position encoding."""
def __init__(self,
n_head,
n_feat,
dropout_rate,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
super().__init__(n_head, n_feat, dropout_rate)
def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
"""重新对齐tensor批量版expand_dims
axes原来的第i维对齐新tensor的第axes[i]
ndim新tensor的维度
"""
assert len(axes) == tensor.dim()
assert ndim or min(axes) >= 0
ndim = ndim or max(axes) + 1
# a[0, None, 1] = a[0, np.newaxis, 1]
indices = [None] * ndim
for i in axes:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices[i] = slice(None)
return tensor[indices]
def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""应用RoPE到tensors中
其中sinusoidal.shape=[B, T, D]tensors为tensor的列表
tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'
ndim = tensors[0].dim()
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
# 公式 34, out = x * cos_pos + x2 * sin_pos
outputs.append(tensor * cos_pos + tensor2 * sin_pos)
return outputs[0] if len(outputs) == 1 else outputs
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache

@ -89,14 +89,17 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate)
self.base = 10000.0
self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]
position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
# base^{-2(i-1)/d)}, i \in (1,2...,d/2)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
-paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
(math.log(self.base) / self.d_model))
# [B,T,D]
self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, :, 1::2] = paddle.cos(position * div_term)

@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "rope_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
chunk_size = xs.shape[1]
attention_key_size = cache_t1 + chunk_size
# only used when using `RelPositionMultiHeadedAttention`
# only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
@ -474,9 +477,22 @@ class ConformerEncoder(BaseEncoder):
activation = get_activation(activation_type)
# self-attention module definition
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (output_size, linear_units, dropout_rate,
@ -580,15 +596,23 @@ class SqueezeformerEncoder(nn.Layer):
activation = get_activation(activation_type)
# self-attention module definition
if pos_enc_layer_type != "rel_pos":
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
else:
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward

@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
self_attn (paddle.nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
`MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.

Loading…
Cancel
Save