add roformer

pull/3407/head
Hui Zhang 12 months ago
parent 94987f26df
commit 03e9ea9e52

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 3
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] __all__ = [
"MultiHeadedAttention", "RelPositionMultiHeadedAttention",
"RoPERelPositionMultiHeadedAttention"
]
# Relative Positional Encodings # Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f # https://www.jianshu.com/p/c0608efcc26f
@ -165,6 +168,7 @@ class MultiHeadedAttention(nn.Layer):
and `head * d_k == size` and `head * d_k == size`
""" """
# (B,T,D) -> (B,T,H,D/H)
q, k, v = self.forward_qkv(query, key, value) q, k, v = self.forward_qkv(query, key, value)
# when export onnx model, for 1st chunk, we feed # when export onnx model, for 1st chunk, we feed
@ -373,3 +377,131 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
self.d_k) # (batch, head, time1, time2) self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache return self.forward_attention(v, scores, mask), new_cache
class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with RoPE relative position encoding."""
def __init__(self,
n_head,
n_feat,
dropout_rate,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
super().__init__(n_head, n_feat, dropout_rate)
def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
"""重新对齐tensor批量版expand_dims
axes原来的第i维对齐新tensor的第axes[i]
ndim新tensor的维度
"""
assert len(axes) == tensor.dim()
assert ndim or min(axes) >= 0
ndim = ndim or max(axes) + 1
# a[0, None, 1] = a[0, np.newaxis, 1]
indices = [None] * ndim
for i in axes:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices[i] = slice(None)
return tensor[indices]
def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""应用RoPE到tensors中
其中sinusoidal.shape=[B, T, D]tensors为tensor的列表
tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'
ndim = tensors[0].dim()
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
# 公式 34, out = x * cos_pos + x2 * sin_pos
outputs.append(tensor * cos_pos + tensor2 * sin_pos)
return outputs[0] if len(outputs) == 1 else outputs
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache

@ -89,14 +89,17 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
self.max_len = max_len self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate) self.dropout = nn.Dropout(p=dropout_rate)
self.base = 10000.0
self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D] self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]
position = paddle.arange( position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1] 0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
# base^{-2(i-1)/d)}, i \in (1,2...,d/2)
div_term = paddle.exp( div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model)) (math.log(self.base) / self.d_model))
# [B,T,D]
self.pe[:, :, 0::2] = paddle.sin(position * div_term) self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, :, 1::2] = paddle.cos(position * div_term) self.pe[:, :, 1::2] = paddle.cos(position * div_term)

@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding
@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos": elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "rope_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos": elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding pos_enc_class = NoPositionalEncoding
else: else:
@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
chunk_size = xs.shape[1] chunk_size = xs.shape[1]
attention_key_size = cache_t1 + chunk_size attention_key_size = cache_t1 + chunk_size
# only used when using `RelPositionMultiHeadedAttention` # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding( pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size) offset=offset - cache_t1, size=attention_key_size)
@ -474,9 +477,22 @@ class ConformerEncoder(BaseEncoder):
activation = get_activation(activation_type) activation = get_activation(activation_type)
# self-attention module definition # self-attention module definition
encoder_selfattn_layer = RelPositionMultiHeadedAttention if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer_args = (attention_heads, output_size, encoder_selfattn_layer = MultiHeadedAttention
attention_dropout_rate) encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition # feed-forward module definition
positionwise_layer = PositionwiseFeedForward positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (output_size, linear_units, dropout_rate, positionwise_layer_args = (output_size, linear_units, dropout_rate,
@ -580,15 +596,23 @@ class SqueezeformerEncoder(nn.Layer):
activation = get_activation(activation_type) activation = get_activation(activation_type)
# self-attention module definition # self-attention module definition
if pos_enc_layer_type != "rel_pos": if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size, encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate) attention_dropout_rate)
else: elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim, encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate, attention_dropout_rate,
adaptive_scale, init_weights) adaptive_scale, init_weights)
elif pos_enc_layer_type == "rope_pos":
encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
else:
raise ValueError(
f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
# feed-forward module definition # feed-forward module definition
positionwise_layer = PositionwiseFeedForward positionwise_layer = PositionwiseFeedForward

@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
Args: Args:
size (int): Input dimension. size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance. self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument. instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance. feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument. `PositionwiseFeedForward`, instance can be used as the argument.
@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
Args: Args:
size (int): Input dimension. size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance. self_attn (nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument. instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance. feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument. `PositionwiseFeedForward` instance can be used as the argument.
@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
Args: Args:
size (int): Input dimension. size (int): Input dimension.
self_attn (paddle.nn.Layer): Self-attention module instance. self_attn (paddle.nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or `RoPERelPositionMultiHeadedAttention`
instance can be used as the argument. instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance. feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument. `PositionwiseFeedForward` instance can be used as the argument.

Loading…
Cancel
Save