support export for asr1

Branch: pull/3572/head
Author: SigureMo, 2 years ago
parent 1dc67f96e0
commit c79d88462f

@@ -3,8 +3,8 @@ set -e
 source path.sh
 gpus=4
-stage=0
-stop_stage=50
+stage=5
+stop_stage=5
 conf_path=conf/transformer.yaml
 ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
@@ -41,7 +41,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/align.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

-if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ]; then
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # export ckpt avg_n
     CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi
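Note: run.sh executes every stage s with stage <= s <= stop_stage, so pinning both to 5 runs only the export step; the old guard's `51` was a typo that made the stage unreachable. local/export.sh turns the averaged dynamic-graph checkpoint into a serialized static graph with the `.jit` suffix. A minimal sketch of the save side, assuming the `infer_model` prepared in `U2Tester.export` below (the output path is illustrative, not the script's exact layout):

    import paddle

    # Once forward_encoder_chunk and friends have been wrapped with
    # paddle.jit.to_static (see the U2Tester hunks below), a single save
    # call serializes them into a static graph loadable from C++.
    paddle.jit.save(infer_model, 'exp/transformer/checkpoints/avg_30.jit')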

@@ -534,8 +534,8 @@ class U2Tester(U2Trainer):
             paddle.static.InputSpec(
                 shape=[None, None, None, None], dtype='float32'),
             # cnn_cache
-            paddle.static.InputSpec(
-                shape=[None, None, None, None], dtype='float32')
+            # paddle.static.InputSpec(
+            #     shape=[None, None, None, None], dtype='float32')
         ]
         infer_model.forward_encoder_chunk = paddle.jit.to_static(
             infer_model.forward_encoder_chunk, input_spec=input_spec)
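Note: each InputSpec dimension set to None declares an axis whose size is only fixed at run time, so one exported graph serves any chunk length and any cache length; the cnn_cache spec is commented out because the pure-Transformer export no longer takes a convolution cache. A toy sketch of the pattern, with assumed shapes (not the model's real signature):

    import paddle
    from paddle.static import InputSpec

    # Stand-in for forward_encoder_chunk; shapes are assumptions.
    def f(xs, att_cache):
        return xs.mean(axis=1) + att_cache.sum()

    static_f = paddle.jit.to_static(
        f,
        input_spec=[
            InputSpec(shape=[1, None, 80], dtype='float32'),             # xs
            InputSpec(shape=[None, None, None, None], dtype='float32')  # att_cache
        ])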
@@ -590,9 +590,11 @@ class U2Tester(U2Trainer):
         offset = paddle.to_tensor([0], dtype='int32')
         required_cache_size = num_left_chunks
         att_cache = paddle.zeros([0, 0, 0, 0])
-        cnn_cache = paddle.zeros([0, 0, 0, 0])
-        xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(
-            xs1, offset, required_cache_size, att_cache, cnn_cache)
+        # cnn_cache = paddle.zeros([0, 0, 0, 0])
+        xs_d, att_cache_d = infer_model.forward_encoder_chunk(
+            xs1, offset, required_cache_size, att_cache,
+            # cnn_cache
+        )

         # load static model
         from paddle.jit.layer import Layer
@@ -604,10 +606,12 @@ class U2Tester(U2Trainer):
         xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
         offset = paddle.to_tensor([0], dtype='int32')
         att_cache = paddle.zeros([0, 0, 0, 0])
-        cnn_cache = paddle.zeros([0, 0, 0, 0])
+        # cnn_cache = paddle.zeros([0, 0, 0, 0])
         func = getattr(layer, 'forward_encoder_chunk')
-        xs_s, att_cache_s, cnn_cache_s = func(xs1, offset, att_cache, cnn_cache)
+        xs_s, att_cache_s = func(xs1, offset, att_cache,
+                                 # cnn_cache
+                                 )
         np.testing.assert_allclose(xs_d, xs_s, atol=1e-5)
         np.testing.assert_allclose(att_cache_d, att_cache_s, atol=1e-4)
-        np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
+        # np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
         # logger.info(f"forward_encoder_chunk output: {xs_s}")
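Note: this test feeds one identical chunk through the dynamic-graph module and the loaded static graph and asserts the outputs agree; the cnn_cache comparison is dropped along with the output itself. Also note the static call omits required_cache_size, which is baked in at export time. The sketch below shows the same parity check with the stable paddle.jit.load API instead of the experimental paddle.jit.layer.Layer used in the diff; the path is hypothetical and `infer_model` is the dynamic-graph model built earlier in `U2Tester.export`:

    import numpy as np
    import paddle

    loaded = paddle.jit.load('exp/transformer/checkpoints/avg_30.jit')

    xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
    offset = paddle.to_tensor([0], dtype='int32')
    att_cache = paddle.zeros([0, 0, 0, 0])

    xs_d, att_cache_d = infer_model.forward_encoder_chunk(
        xs1, offset, -1, att_cache)                       # -1: assumed cache size
    xs_s, att_cache_s = loaded.forward_encoder_chunk(xs1, offset, att_cache)

    # atol is looser for the cache: dynamic and static graphs may order
    # float32 reductions differently, so bit-exact equality is not expected.
    np.testing.assert_allclose(xs_d.numpy(), xs_s.numpy(), atol=1e-5)
    np.testing.assert_allclose(att_cache_d.numpy(), att_cache_s.numpy(), atol=1e-4)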

@@ -651,7 +651,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
             offset: int,
             required_cache_size: int,
             att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+            # cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
@@ -685,7 +685,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
             same shape as the original cnn_cache.
         """
         return self.encoder.forward_chunk(xs, offset, required_cache_size,
-                                           att_cache, cnn_cache)
+                                           att_cache,
+                                           # cnn_cache
+                                           )

     # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
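Note: forward_encoder_chunk is the C++-facing streaming entry point: each call consumes one feature chunk and returns the encoder output plus a refreshed attention cache to feed into the next call. (The Tuple return annotation above still advertises three tensors, though only two are returned after this change.) A hedged sketch of a caller-side loop; the chunk size, `model` name, and `-1` cache-size value are assumptions, not PaddleSpeech API:

    import paddle

    def stream_encode(model, feats, chunk=67):
        offset = paddle.to_tensor([0], dtype='int32')
        att_cache = paddle.zeros([0, 0, 0, 0])  # empty cache on the first chunk
        outs = []
        for start in range(0, feats.shape[1], chunk):
            xs = feats[:, start:start + chunk, :]
            # -1 requests an unbounded left-context cache here (assumption).
            ys, att_cache = model.forward_encoder_chunk(xs, offset, -1, att_cache)
            offset = offset + ys.shape[1]  # advance by the emitted frames
            outs.append(ys)
        return paddle.concat(outs, axis=1)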

@@ -200,7 +200,7 @@ class BaseEncoder(nn.Layer):
             offset: int,
             required_cache_size: int,
             att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            # cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
             att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
@@ -252,7 +252,7 @@ class BaseEncoder(nn.Layer):
         next_cache_start = max(attention_key_size - required_cache_size, 0)
         r_att_cache = []
-        r_cnn_cache = []
+        # r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
             # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
@@ -262,25 +262,27 @@ class BaseEncoder(nn.Layer):
             # raw code as below:
             # att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
             # cnn_cache=cnn_cache[i:i+1] if cnn_cache.shape[0] > 0 else cnn_cache,
-            xs, _, new_att_cache, new_cnn_cache = layer(
+            xs, _, new_att_cache = layer(
                 xs,
                 att_mask,
                 pos_emb,
                 att_cache=att_cache[i:i + 1],
-                cnn_cache=cnn_cache[i:i + 1], )
+                # cnn_cache=cnn_cache[i:i + 1],
+            )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
             r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
-            r_cnn_cache.append(new_cnn_cache)  # add elayer dim
+            # r_cnn_cache.append(new_cnn_cache)  # add elayer dim

         if self.normalize_before:
             xs = self.after_norm(xs)

         # r_att_cache (elayers, head, T, d_k*2)
         # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
+        # breakpoint()
         r_att_cache = paddle.concat(r_att_cache, axis=0)
-        r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
-        return xs, r_att_cache, r_cnn_cache
+        # r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
+        return xs, r_att_cache  # , r_cnn_cache

     def forward_chunk_by_chunk(
         self,
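Note: the cache bookkeeping above keeps one stacked tensor for all layers: layer i reads slice att_cache[i:i+1], returns a refreshed slice trimmed to the required left context, and the slices are concatenated back along dim 0 for the next call. (As with the other hunks, the Tuple return annotation above still lists the dropped cnn_cache.) A toy illustration with stand-in shapes and a stand-in for `layer(...)`:

    import paddle

    elayers, head, t, d = 4, 8, 16, 128           # assumed sizes
    att_cache = paddle.zeros([elayers, head, t, d])
    next_cache_start = 4                          # frames older than this are dropped

    r_att_cache = []
    for i in range(elayers):
        new_att_cache = att_cache[i:i + 1] + 1.0  # stand-in for layer(...)
        r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])

    # Concatenating along dim 0 restores the (elayers, head, T, d_k*2) layout
    # that the next forward_chunk call slices per layer again.
    r_att_cache = paddle.concat(r_att_cache, axis=0)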

@@ -81,7 +81,7 @@ class TransformerEncoderLayer(nn.Layer):
             pos_emb: paddle.Tensor,
             mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
             att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
+            # cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.

         Args:
@@ -125,8 +125,8 @@ class TransformerEncoderLayer(nn.Layer):
         if not self.normalize_before:
             x = self.norm2(x)

-        fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
-        return x, mask, new_att_cache, fake_cnn_cache
+        # fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
+        return x, mask, new_att_cache  # , fake_cnn_cache

 class ConformerEncoderLayer(nn.Layer):
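Note: fake_cnn_cache only ever existed to keep TransformerEncoderLayer's return tuple shaped like ConformerEncoderLayer's, whose convolution module carries real state between chunks; once cnn_cache leaves the interface, the zero-size padding tensor goes with it. A minimal sketch of the old versus new contract, using stand-in functions rather than the actual layers:

    import paddle

    def transformer_layer_old(x, mask, att_cache):
        fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
        return x, mask, att_cache, fake_cnn_cache  # padded to match Conformer

    def transformer_layer_new(x, mask, att_cache):
        return x, mask, att_cache                  # cnn_cache dropped entirely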
