configurable export

pull/2212/head
Hui Zhang 2 years ago
parent 63aeb747b0
commit 1c9f238ba0

@ -462,31 +462,37 @@ class U2Tester(U2Trainer):
infer_model = U2InferModel.from_pretrained(self.test_loader, infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.clone(), self.config.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
batch_size = 1
feat_dim = self.test_loader.feat_dim feat_dim = self.test_loader.feat_dim
input_spec = [ model_size = 512
paddle.static.InputSpec(shape=[1, None, feat_dim], num_left_chunks = -1
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[1], return infer_model, (batch_size, feat_dim, model_size, num_left_chunks)
dtype='int64'), # audio_length, [B]
]
return infer_model, input_spec
@paddle.no_grad() @paddle.no_grad()
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
del input_spec
infer_model.eval() infer_model.eval()
######################### infer_model.forward_encoder_chunk zero Tensor online ############ assert isinstance(input_spec, list), type(input_spec)
batch_size, feat_dim, model_size, num_left_chunks = input_spec
######################### infer_model.forward_encoder_chunk zero tensor online ############
# TODO: 80(feature dim) be configable # TODO: 80(feature dim) be configable
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'), # xs, (B, T, D)
paddle.static.InputSpec(shape=[batch_size, None, feat_dim], dtype='float32'),
# offset, int, but need be tensor
paddle.static.InputSpec(shape=[1], dtype='int32'), paddle.static.InputSpec(shape=[1], dtype='int32'),
-1, # required_cache_size, int
num_left_chunks,
# att_cache
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, None, None, None], shape=[None, None, None, None],
dtype='float32'), dtype='float32'),
# cnn_cache
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32') shape=[None, None, None, None], dtype='float32')
] ]
@ -496,9 +502,12 @@ class U2Tester(U2Trainer):
######################### infer_model.forward_attention_decoder ######################## ######################### infer_model.forward_attention_decoder ########################
# TODO: 512(encoder_output) be configable. 1 for BatchSize # TODO: 512(encoder_output) be configable. 1 for BatchSize
input_spec = [ input_spec = [
# hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# hyps_lens, (B,)
paddle.static.InputSpec(shape=[None], dtype='int64'), paddle.static.InputSpec(shape=[None], dtype='int64'),
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32') # encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
] ]
infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec) infer_model.forward_attention_decoder, input_spec=input_spec)
@ -529,7 +538,7 @@ class U2Tester(U2Trainer):
xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32') xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
offset = paddle.to_tensor([0], dtype='int32') offset = paddle.to_tensor([0], dtype='int32')
required_cache_size = -16 required_cache_size = num_left_chunks
att_cache = paddle.zeros([0, 0, 0, 0]) att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0]) cnn_cache = paddle.zeros([0, 0, 0, 0])

Loading…
Cancel
Save