RoPE with position interpolation

pull/3407/head
Hui Zhang 12 months ago
parent b91b1c9b08
commit b56fb85ca0

@@ -20,30 +20,6 @@ import numpy as np
 import paddle
 
-def define_argparse():
-    parser = argparse.ArgumentParser(description='average model')
-    parser.add_argument('--dst_model', required=True, help='averaged model')
-    parser.add_argument(
-        '--ckpt_dir', required=True, help='ckpt model dir for average')
-    parser.add_argument(
-        '--val_best', action="store_true", help='averaged model')
-    parser.add_argument(
-        '--num', default=5, type=int, help='nums for averaged model')
-    parser.add_argument(
-        '--min_epoch',
-        default=0,
-        type=int,
-        help='min epoch used for averaging model')
-    parser.add_argument(
-        '--max_epoch',
-        default=65536,  # Big enough
-        type=int,
-        help='max epoch used for averaging model')
-    args = parser.parse_args()
-    return args
-
 def average_checkpoints(dst_model="",
                         ckpt_dir="",
                         val_best=True,
@@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
     print(path_list)
 
     avg = None
-    num = args.num
+    num = num
     assert num == len(path_list)
     for path in path_list:
         print(f'Processing {path}')
@@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
         if avg[k] is not None:
             avg[k] /= num
 
-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
 
-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
     with open(meta_path, 'w') as f:
         data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
             "val_loss_mean": avg_val_score,
             "ckpts": path_list,
             "epochs": selected_epochs.tolist(),
@@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
         f.write(data + "\n")
 
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='averaged model')
+    parser.add_argument(
+        '--num', default=5, type=int, help='nums for averaged model')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+    args = parser.parse_args()
+    print(args)
+    return args
+
 def main():
     args = define_argparse()
-    average_checkpoints(args)
+    average_checkpoints(
+        dst_model=args.dst_model,
+        ckpt_dir=args.ckpt_dir,
+        val_best=args.val_best,
+        num=args.num,
+        min_epoch=args.min_epoch,
+        max_epoch=args.max_epoch)
 
 if __name__ == '__main__':
     main()

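For context: after this refactor, `average_checkpoints` no longer depends on the argparse namespace, so it can be driven from code as well as from the CLI. A minimal usage sketch (paths and import location are hypothetical):

    # Hypothetical import path; adjust to wherever avg_model.py lives in the repo.
    from avg_model import average_checkpoints

    average_checkpoints(
        dst_model="exp/avg_5.pdparams",  # hypothetical output file
        ckpt_dir="exp/checkpoints",      # hypothetical checkpoint directory
        val_best=True,                   # average the checkpoints with best val loss
        num=5,
        min_epoch=0,
        max_epoch=65536)
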
@@ -85,11 +85,11 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             reverse (bool, optional): Not used. Defaults to False.
         """
         nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.base = 10000.0
+        self.base = paddle.to_tensor(10000.0)
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
         position = paddle.arange(
@@ -97,7 +97,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         # base^{-2(i-1)/d}, i \in (1,2,...,d/2)
         div_term = paddle.exp(
             -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            (math.log(self.base) / self.d_model))
+            (paddle.log(self.base) / self.d_model))
 
         # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
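
The `div_term` above is the standard sinusoidal inverse-frequency vector; writing it as exp(-k · ln(base)/d) for k = 0, 2, ..., d-2 is just a numerically convenient form of base^{-k/d}. A quick standalone check (NumPy, independent of this patch):

    import numpy as np

    d_model, base = 8, 10000.0
    k = np.arange(0, d_model, 2, dtype=np.float32)    # k = 2(i-1), i = 1..d/2
    div_term = np.exp(-k * (np.log(base) / d_model))  # exp/log form used above
    direct = base ** (-k / d_model)                   # equivalent power form
    assert np.allclose(div_term, direct)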
@@ -188,19 +188,73 @@ class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
             scale (int): Interpolate the max input length to `scale * max_len` positions.
         """
         super().__init__(d_model, dropout_rate, max_len, reverse=True)
-        self.scale = scale
+        self.pscale = paddle.to_tensor(scale)
         self.max_len = max_len * scale
+    def sinusoidal_embeddings(self,
+                              pos: paddle.Tensor,
+                              dim: paddle.Tensor,
+                              base=10000) -> paddle.Tensor:
+        """Compute dim-dimensional sinusoidal embeddings for positions `pos`."""
+        assert dim % 2 == 0
+        # (d/2,)
+        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
+        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
+        # pos (1, T), indices (d/2,) -> (1, T, d/2)
+        embeddings = paddle.einsum('...,d->...d', pos, indices)
+        # (1, T, d/2, 2)
+        embeddings = paddle.stack(
+            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
+        # (1, T, d)
+        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
+        return embeddings
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        x = x * self.xscale
+        B, T, D = x.shape
+        assert D == self.d_model
+        # position interpolation
+        start = 0
+        end = T * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        pos_emb = pe[:, offset:offset + x.shape[1]]
+        return self.dropout(x), self.dropout(pos_emb)
     def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
         """For getting encoding in a streaming fashion.
         Attention!!!!!
         We apply dropout only once at the whole utterance level in a
         non-streaming way, but will call this function several times with
         increasing input size in a streaming scenario, so the dropout will
         be applied several times.
         Args:
             offset (int): start offset
             size (int): required size of position encoding
         Returns:
             paddle.Tensor: Corresponding position encoding, #[1, T, D].
         """
-        position = paddle.arange(
-            0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
-        # position interpolation
-        position *= 1.0 / self.scale
-        # base^{-2(i-1)/d}, i \in (1,2,...,d/2)
-        div_term = paddle.exp(
-            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            (math.log(self.base) / self.d_model))
-        # [B,T,D]
-        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
-        self.pe[:, :, 1::2] = paddle.cos(position * div_term)
+        # position interpolation
+        start = offset
+        end = (offset + size) * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(
+            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        return self.dropout(pe)
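
To see what the interpolation buys: both `forward` and `position_encoding` stretch the index range by `pscale` and then multiply by `1/pscale`, so position t is encoded as the fractional position t/pscale, which always stays inside the original trained range [0, max_len). A minimal NumPy sketch of the idea (toy sizes assumed, mirroring `sinusoidal_embeddings`):

    import numpy as np

    def sinusoidal_embeddings(pos, dim, base=10000.0):
        """pos: (T,) positions -> (T, dim) interleaved [sin, cos] embeddings."""
        inv_freq = base ** (-2.0 * np.arange(dim // 2, dtype=np.float32) / dim)
        angles = pos[:, None] * inv_freq[None, :]          # (T, dim/2)
        emb = np.stack([np.sin(angles), np.cos(angles)], axis=-1)
        return emb.reshape(len(pos), dim)                  # (T, dim)

    max_len, pscale, d = 8, 4, 16                          # toy sizes (assumptions)
    t = np.arange(0, max_len * pscale, dtype=np.float32)   # input 4x longer...
    pe = sinusoidal_embeddings(t / pscale, d)              # ...mapped into [0, max_len)
    assert (t / pscale).max() < max_len                    # stays in trained range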
