|
|
|
@ -534,8 +534,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
paddle.static.InputSpec(
|
|
|
|
|
shape=[None, None, None, None], dtype='float32'),
|
|
|
|
|
# cnn_cache
|
|
|
|
|
paddle.static.InputSpec(
|
|
|
|
|
shape=[None, None, None, None], dtype='float32')
|
|
|
|
|
# paddle.static.InputSpec(
|
|
|
|
|
# shape=[None, None, None, None], dtype='float32')
|
|
|
|
|
]
|
|
|
|
|
infer_model.forward_encoder_chunk = paddle.jit.to_static(
|
|
|
|
|
infer_model.forward_encoder_chunk, input_spec=input_spec)
|
|
|
|
@ -590,9 +590,11 @@ class U2Tester(U2Trainer):
|
|
|
|
|
offset = paddle.to_tensor([0], dtype='int32')
|
|
|
|
|
required_cache_size = num_left_chunks
|
|
|
|
|
att_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
cnn_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(
|
|
|
|
|
xs1, offset, required_cache_size, att_cache, cnn_cache)
|
|
|
|
|
# cnn_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
xs_d, att_cache_d = infer_model.forward_encoder_chunk(
|
|
|
|
|
xs1, offset, required_cache_size, att_cache,
|
|
|
|
|
# cnn_cache
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# load static model
|
|
|
|
|
from paddle.jit.layer import Layer
|
|
|
|
@ -604,10 +606,12 @@ class U2Tester(U2Trainer):
|
|
|
|
|
xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
|
|
|
|
|
offset = paddle.to_tensor([0], dtype='int32')
|
|
|
|
|
att_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
cnn_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
# cnn_cache = paddle.zeros([0, 0, 0, 0])
|
|
|
|
|
func = getattr(layer, 'forward_encoder_chunk')
|
|
|
|
|
xs_s, att_cache_s, cnn_cache_s = func(xs1, offset, att_cache, cnn_cache)
|
|
|
|
|
xs_s, att_cache_s = func(xs1, offset, att_cache,
|
|
|
|
|
# cnn_cache
|
|
|
|
|
)
|
|
|
|
|
np.testing.assert_allclose(xs_d, xs_s, atol=1e-5)
|
|
|
|
|
np.testing.assert_allclose(att_cache_d, att_cache_s, atol=1e-4)
|
|
|
|
|
np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
|
|
|
|
|
# np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
|
|
|
|
|
# logger.info(f"forward_encoder_chunk output: {xs_s}")
|
|
|
|
|