RoPE with position interpolation

pull/3407/head
Hui Zhang 1 year ago
parent b91b1c9b08
commit b56fb85ca0

@@ -20,30 +20,6 @@ import numpy as np
 import paddle
 
-
-def define_argparse():
-    parser = argparse.ArgumentParser(description='average model')
-    parser.add_argument('--dst_model', required=True, help='averaged model')
-    parser.add_argument(
-        '--ckpt_dir', required=True, help='ckpt model dir for average')
-    parser.add_argument(
-        '--val_best', action="store_true", help='averaged model')
-    parser.add_argument(
-        '--num', default=5, type=int, help='nums for averaged model')
-    parser.add_argument(
-        '--min_epoch',
-        default=0,
-        type=int,
-        help='min epoch used for averaging model')
-    parser.add_argument(
-        '--max_epoch',
-        default=65536,  # Big enough
-        type=int,
-        help='max epoch used for averaging model')
-    args = parser.parse_args()
-    return args
-
-
 def average_checkpoints(dst_model="",
                         ckpt_dir="",
                         val_best=True,
@@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
     print(path_list)
     avg = None
-    num = args.num
+    num = num
     assert num == len(path_list)
     for path in path_list:
         print(f'Processing {path}')
@@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
         if avg[k] is not None:
             avg[k] /= num
-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
     with open(meta_path, 'w') as f:
         data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
             "val_loss_mean": avg_val_score,
             "ckpts": path_list,
             "epochs": selected_epochs.tolist(),
@@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
         f.write(data + "\n")
 
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='averaged model')
+    parser.add_argument(
+        '--num', default=5, type=int, help='nums for averaged model')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
 def main():
     args = define_argparse()
-    average_checkpoints(args)
+    average_checkpoints(
+        dst_model=args.dst_model,
+        ckpt_dir=args.ckpt_dir,
+        val_best=args.val_best,
+        num=args.num,
+        min_epoch=args.min_epoch,
+        max_epoch=args.max_epoch)
 
 
 if __name__ == '__main__':
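
With this refactor, `average_checkpoints` takes plain keyword arguments instead of an `argparse.Namespace`, so it can be reused as a library function. A minimal usage sketch, assuming the script is importable as `avg_model` (the module name and paths here are illustrative, not part of the commit):

```python
# Hypothetical usage of the refactored API; module name and paths are
# illustrative, not part of the commit.
from avg_model import average_checkpoints

average_checkpoints(
    dst_model='exp/avg_5.pdparams',  # output file for the averaged weights
    ckpt_dir='exp/checkpoints',      # directory holding the saved checkpoints
    val_best=True,                   # pick the num best-val-loss checkpoints
    num=5,
    min_epoch=0,
    max_epoch=65536)
```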

@@ -85,11 +85,11 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             reverse (bool, optional): Not used. Defaults to False.
         """
         nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.base = 10000.0
+        self.base = paddle.to_tensor(10000.0)
         self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]
         position = paddle.arange(
@@ -97,7 +97,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         # base^{-2(i-1)/d)}, i \in (1,2...,d/2)
         div_term = paddle.exp(
             -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            (math.log(self.base) / self.d_model))
+            (paddle.log(self.base) / self.d_model))
         # [B,T,D]
         self.pe[:, :, 0::2] = paddle.sin(position * div_term)
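
Switching from `math.log` to `paddle.log` is required once `self.base` becomes a tensor; the value is unchanged, since exp(-2i/d · log base) = base^(-2i/d). A quick standalone check of that identity (not part of the commit):

```python
import paddle

d_model = 256
base = paddle.to_tensor(10000.0)
i = paddle.arange(0, d_model, 2, dtype=paddle.float32)   # 2k, k in 0..d/2-1
via_exp = paddle.exp(-i * (paddle.log(base) / d_model))  # exp(-2k/d * log base)
via_pow = paddle.pow(base, -i / d_model)                 # base^(-2k/d)
print(bool(paddle.allclose(via_exp, via_pow)))           # True
```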
@@ -188,19 +188,73 @@ class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
             scale (int): Interpolation max input length to `scale * max_len` positions.
         """
         super().__init__(d_model, dropout_rate, max_len, reverse=True)
-        self.scale = scale
+        self.pscale = paddle.to_tensor(scale)
         self.max_len = max_len * scale
-        position = paddle.arange(
-            0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
-        # position interpolation
-        position *= 1.0 / self.scale
-        # base^{-2(i-1)/d)}, i \in (1,2...,d/2)
-        div_term = paddle.exp(
-            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            (math.log(self.base) / self.d_model))
-        # [B,T,D]
-        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
-        self.pe[:, :, 1::2] = paddle.cos(position * div_term)
+
+    def sinusoidal_embeddings(self,
+                              pos: paddle.Tensor,
+                              dim: paddle.Tensor,
+                              base=10000) -> paddle.Tensor:
+        """Compute dim-dimensional sinusoidal embeddings for positions `pos`."""
+        assert dim % 2 == 0
+        # (d/2,)
+        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
+        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
+        # pos (1, T), indices (d/2,) -> (1, T, d/2)
+        embeddings = paddle.einsum('...,d->...d', pos, indices)
+        # (1, T, d/2, 2)
+        embeddings = paddle.stack(
+            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
+        # (1, T, d)
+        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
+        return embeddings
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        x = x * self.xscale
+        B, T, D = x.shape
+        assert D == self.d_model
+        # position interpolation
+        start = 0
+        end = T * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        pos_emb = pe[:, offset:offset + x.shape[1]]
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        """For getting encoding in a streaming fashion.
+        Attention!!!!!
+        We apply dropout only once at the whole utterance level in a
+        non-streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+        Args:
+            offset (int): start offset
+            size (int): required size of position encoding
+        Returns:
+            paddle.Tensor: Corresponding position encoding, #[1, T, D].
+        """
+        # position interpolation
+        start = offset
+        end = (offset + size) * self.pscale
+        assert end <= self.max_len
+        position = paddle.arange(
+            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
+        position *= 1.0 / self.pscale
+        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
+        return self.dropout(pe)
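
The interpolation trick is easiest to see in isolation: rather than asking the sinusoidal tables for positions the model never saw, every position index is shrunk by `1/scale`, so a sequence `scale` times longer than `max_len` still maps into the trained range `[0, max_len)`. A self-contained sketch of that mapping (the helper mirrors `sinusoidal_embeddings` above; all names and values are illustrative):

```python
import paddle

def sinusoidal_embeddings(pos, dim, base=10000.0):
    """Interleaved sin/cos embeddings, mirroring the method added above."""
    indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
    indices = paddle.pow(
        paddle.to_tensor(base, dtype=pos.dtype), -2 * indices / dim)
    angles = paddle.einsum('...,d->...d', pos, indices)  # (1, T, d/2)
    emb = paddle.stack([paddle.sin(angles), paddle.cos(angles)], axis=-1)
    return paddle.flatten(emb, start_axis=-2, stop_axis=-1)  # (1, T, d)

max_len, scale, d_model = 8, 2, 16
T = max_len * scale  # twice the trained context length
position = paddle.arange(0, T, dtype='float32').unsqueeze(0)  # 0, 1, ..., 15
position *= 1.0 / scale  # interpolated: 0.0, 0.5, ..., 7.5
pe = sinusoidal_embeddings(position, d_model)
print(pe.shape)               # [1, 16, 16]
print(float(position.max()))  # 7.5, inside the trained range [0, 8)
```

The streaming `position_encoding` path applies the same rescaling, just starting from `offset` instead of 0.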
