fix jit save for conformer

pull/2212/head
Hui Zhang 2 years ago
parent 4e7106d9e2
commit e5a6c243f1

@@ -25,8 +25,6 @@ import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
@@ -109,7 +107,8 @@ class U2Trainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -136,7 +135,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -157,7 +157,8 @@ class U2Trainer(Trainer):
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -225,14 +226,18 @@ class U2Trainer(Trainer):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
@@ -470,166 +475,30 @@ class U2Tester(U2Trainer):
def export(self):
infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec)
del input_spec
infer_model.eval()
# static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
# logger.info(f"Export code: {static_model.forward.code}")
# paddle.jit.save(static_model, self.args.export_path)
# # to check outputs
# def flatten(out):
# if isinstance(out, paddle.Tensor):
# return [out]
# flatten_out = []
# for var in out:
# if isinstance(var, (list, tuple)):
# flatten_out.extend(flatten(var))
# else:
# flatten_out.append(var)
# return flatten_out
# ######################### infer_model.forward_attention_decoder ########################
# a = paddle.full(shape=[10, 8], fill_value=10, dtype='int64')
# b = paddle.full(shape=[10], fill_value=8, dtype='int64')
# # c = paddle.rand(shape=[1, 20, 512], dtype='float32')
# c = paddle.full(shape=[1, 20, 512], fill_value=1, dtype='float32')
# out1 = infer_model.forward_attention_decoder(a, b, c)
# print(out1)
# input_spec = [paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]
# static_model = paddle.jit.to_static(infer_model.forward_attention_decoder, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path)
# static_model = paddle.jit.load(self.args.export_path)
# out2 = static_model(a, b, c)
# # print(out2)
# out1 = flatten(out1)
# out2 = flatten(out2)
# for i in range(len(out1)):
# print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
# ######################### infer_model.forward_encoder_chunk ########################
# xs = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([80], dtype='int32')
# required_cache_size = -16
# att_cache = paddle.randn(shape=[12, 8, 80, 128], dtype='float32')
# cnn_cache = paddle.randn(shape=[12, 1, 512, 14], dtype='float32')
# # out1 = infer_model.forward_encoder_chunk(xs, offset, required_cache_size, att_cache, cnn_cache)
# # print(out1)
# zero_out1 = infer_model.forward_encoder_chunk(xs, offset, required_cache_size, att_cache=paddle.zeros([0, 0, 0, 0]), cnn_cache=paddle.zeros([0, 0, 0, 0]))
# # print(zero_out1)
# input_spec = [
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -16,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
# static_model = paddle.jit.to_static(infer_model.forward_encoder_chunk, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path)
# static_model = paddle.jit.load(self.args.export_path)
# # out2 = static_model(xs, offset, att_cache, cnn_cache)
# # print(out2)
# zero_out2 = static_model(xs, offset, paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
# # out1 = flatten(out1)
# # out2 = flatten(out2)
# # for i in range(len(out1)):
# # print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
# zero_out1 = flatten(zero_out1)
# zero_out2 = flatten(zero_out2)
# for i in range(len(zero_out1)):
# print(np.equal(zero_out1[i].numpy(), zero_out2[i].numpy()).all())
# ######################### infer_model.forward_encoder_chunk zero Tensor online ########################
# xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([0], dtype='int32')
# required_cache_size = -16
# att_cache = paddle.zeros([0, 0, 0, 0])
# cnn_cache=paddle.zeros([0, 0, 0, 0])
# xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
# xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([16], dtype='int32')
# out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
# # print(out1)
# input_spec = [
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -16,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
# static_model = paddle.jit.to_static(infer_model.forward_encoder_chunk, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path)
# static_model = paddle.jit.load(self.args.export_path)
# offset = paddle.to_tensor([0], dtype='int32')
# att_cache = paddle.zeros([0, 0, 0, 0])
# cnn_cache=paddle.zeros([0, 0, 0, 0])
# xs, att_cache, cnn_cache = static_model(xs1, offset, att_cache, cnn_cache)
# xs = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([16], dtype='int32')
# out2 = static_model(xs2, offset, att_cache, cnn_cache)
# # print(out2)
# out1 = flatten(out1)
# out2 = flatten(out2)
# for i in range(len(out1)):
# print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
###################### save/load combine ########################
paddle.jit.save(infer_model, '/workspace/conformer/PaddleSpeech-conformer/conformer/conformer', combine_params=True)
######################### infer_model.forward_encoder_chunk zero Tensor online ########################
input_spec = [
paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
paddle.static.InputSpec(shape=[1], dtype='int32'), -1,
paddle.static.InputSpec(
shape=[None, None, None, None],
dtype='float32'), paddle.static.InputSpec(
shape=[None, None, None, None], dtype='float32')
]
infer_model.forward_encoder_chunk = paddle.jit.to_static(
infer_model.forward_encoder_chunk, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
# xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([0], dtype='int32')
# required_cache_size = -16
# att_cache = paddle.zeros([0, 0, 0, 0])
# cnn_cache=paddle.zeros([0, 0, 0, 0])
######################### infer_model.forward_attention_decoder ########################
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
paddle.static.InputSpec(shape=[None], dtype='int64'),
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
]
infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
# xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
# xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
# offset = paddle.to_tensor([16], dtype='int32')
# out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
# # print(out1)
# from paddle.jit.layer import Layer
# layer = Layer()
# layer.load('/workspace/conformer/PaddleSpeech-conformer/conformer/conformer', paddle.CUDAPlace(0))
# offset = paddle.to_tensor([0], dtype='int32')
# att_cache = paddle.zeros([0, 0, 0, 0])
# cnn_cache=paddle.zeros([0, 0, 0, 0])
# xs, att_cache, cnn_cache = layer.forward_encoder_chunk(xs1, offset, att_cache, cnn_cache)
# offset = paddle.to_tensor([16], dtype='int32')
# out2 = layer.forward_encoder_chunk(xs2, offset, att_cache, cnn_cache)
# # print(out2)
# out1 = flatten(out1)
# out2 = flatten(out2)
# for i in range(len(out1)):
# print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
paddle.jit.save(infer_model, './export.jit', combine_params=True)
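After export() rebinds forward_encoder_chunk and forward_attention_decoder to their static versions and saves the combined program, the artifact can be smoke-tested from Python. A minimal sketch, assuming a Paddle build where paddle.jit.load restores the extra saved methods on the returned layer (path and shapes are illustrative, taken from the debug code above):

import paddle

loaded = paddle.jit.load('./export.jit')

# First chunk: empty [0, 0, 0, 0] caches and offset 0, as in the
# commented-out checks above. required_cache_size is baked into the
# static graph via the input_spec (-1), so it is not passed at call time.
xs = paddle.rand(shape=[1, 67, 80], dtype='float32')
offset = paddle.to_tensor([0], dtype='int32')
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
ys, att_cache, cnn_cache = loaded.forward_encoder_chunk(xs, offset,
                                                        att_cache, cnn_cache)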

@@ -29,6 +29,9 @@ import paddle
from paddle import jit
from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
@@ -48,9 +51,6 @@ from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig
@ -59,20 +59,6 @@ __all__ = ["U2Model", "U2InferModel"]
logger = Log(__name__).getlog()
# input_spec1 = [paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]
# input_spec2 = [
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -16,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]
# input_spec3 = [paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
# paddle.static.InputSpec(shape=[1], dtype='int64')]
class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
@@ -588,44 +574,44 @@ class U2BaseModel(ASRInterface, nn.Layer):
best_index = i
return hyps[best_index][0]
#@jit.to_static
@jit.to_static(property=True)
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
#@jit.to_static
@jit.to_static(property=True)
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
#@jit.to_static
@jit.to_static(property=True)
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
#@jit.to_static
@jit.to_static(property=True)
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
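@jit.to_static(property=True) records each of these scalar getters in the saved program so a C++ caller can read them without running a graph. A toy, self-contained illustration of the decorator pattern, assuming the Paddle develop build this commit targets (Toy and _rate are made-up names):

import paddle
from paddle import jit, nn

class Toy(nn.Layer):
    def __init__(self):
        super().__init__()
        self._rate = 4  # made-up scalar attribute

    @jit.to_static(property=True)
    def subsampling_rate(self) -> int:
        # Exported as a named scalar property, not as a graph function.
        return self._rate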
@jit.to_static(input_spec=[
paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
paddle.static.InputSpec(shape=[1], dtype='int32'),
-16,
paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -1,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
@@ -660,8 +646,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, att_cache, cnn_cache)
return self.encoder.forward_chunk(xs, offset, required_cache_size,
att_cache, cnn_cache)
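The cache handshake described in the docstring is exactly what the commented-out streaming checks in export() exercise: start from empty [0, 0, 0, 0] tensors, then feed each chunk's returned caches into the next call. Condensed into a dynamic-graph sketch (infer_model as produced by load_inferspec() in the trainer above; shapes follow the debug code):

import paddle

xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32')
offset = paddle.to_tensor([0], dtype='int32')
att_cache = paddle.zeros([0, 0, 0, 0])  # empty cache: first chunk
cnn_cache = paddle.zeros([0, 0, 0, 0])
ys, att_cache, cnn_cache = infer_model.forward_encoder_chunk(
    xs1, offset, -16, att_cache, cnn_cache)

# Next chunk: offset advances by the subsampled frames already consumed,
# and the caches returned above carry the attention/conv context forward.
xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
offset = paddle.to_tensor([16], dtype='int32')
ys, att_cache, cnn_cache = infer_model.forward_encoder_chunk(
    xs2, offset, -16, att_cache, cnn_cache)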
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
@@ -674,10 +660,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.to_static(input_spec=[
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
paddle.static.InputSpec(shape=[None], dtype='int64'),
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
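For this decoder entry point, the earlier debug block gives the expected call shape: a batch of padded hypotheses, their lengths, and a single encoder output. A minimal invocation sketch (hyps_lens and encoder_out are descriptive names for the second and third arguments; values mirror the debug code, with infer_model as in the trainer above):

import paddle

hyps = paddle.full(shape=[10, 8], fill_value=10, dtype='int64')   # 10 padded hypotheses
hyps_lens = paddle.full(shape=[10], fill_value=8, dtype='int64')  # their true lengths
encoder_out = paddle.full(shape=[1, 20, 512], fill_value=1, dtype='float32')
decoder_out = infer_model.forward_attention_decoder(hyps, hyps_lens, encoder_out)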
@@ -941,8 +927,9 @@ class U2InferModel(U2Model):
super().__init__(configs)
@jit.to_static(input_spec=[
paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
paddle.static.InputSpec(shape=[1], dtype='int64')])
paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
paddle.static.InputSpec(shape=[1], dtype='int64')
])
def forward(self,
feats,
feats_lengths,
@@ -958,6 +945,7 @@ class U2InferModel(U2Model):
Returns:
List[List[int]]: best path result
"""
# dummy code for dy2st
# return self.ctc_greedy_search(
# feats,
# feats_lengths,
