pull/2502/head
Hui Zhang 2 years ago
parent e86337a423
commit 1f4f98b171

@ -18,6 +18,7 @@ from pathlib import Path
import paddle import paddle
import soundfile import soundfile
from paddleslim import PTQ
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
@ -26,7 +27,6 @@ from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddleslim import PTQ
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -90,14 +90,14 @@ class U2Infer():
ctc_weight=decode_config.ctc_weight, ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size, decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks, num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming simulate_streaming=decode_config.simulate_streaming,
reverse_weight=decode_config.reverse_weight) reverse_weight=decode_config.reverse_weight)
rsl = result_transcripts[0][0] rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}") logger.info(f"hyp: {utt} {rsl}")
# print(self.model) # print(self.model)
# print(self.model.forward_encoder_chunk) # print(self.model.forward_encoder_chunk)
logger.info("-------------start quant ----------------------") logger.info("-------------start quant ----------------------")
batch_size = 1 batch_size = 1
feat_dim = 80 feat_dim = 80
@ -161,7 +161,11 @@ class U2Infer():
# jit save # jit save
logger.info(f"export save: {self.args.export_path}") logger.info(f"export save: {self.args.export_path}")
config = {'is_static': True, 'combine_params':True, 'skip_forward':True} config = {
'is_static': True,
'combine_params': True,
'skip_forward': True
}
self.ptq.save_quantized_model(self.model, self.args.export_path) self.ptq.save_quantized_model(self.model, self.args.export_path)
# paddle.jit.save( # paddle.jit.save(
# self.model, # self.model,
@ -169,7 +173,6 @@ class U2Infer():
# combine_params=True, # combine_params=True,
# skip_forward=True) # skip_forward=True)
def check(audio_file): def check(audio_file):
if not os.path.isfile(audio_file): if not os.path.isfile(audio_file):
@ -201,7 +204,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--audio_file", type=str, help="path of the input audio file") "--audio_file", type=str, help="path of the input audio file")
parser.add_argument( parser.add_argument(
"--export_path", type=str, default='export', help="path of the input audio file") "--export_path",
type=str,
default='export',
help="path of the input audio file")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)

@ -131,7 +131,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
if self.ctc_weight != 1.0: if self.ctc_weight != 1.0:
start = time.time() start = time.time()
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths, self.reverse_weight) text, text_lengths,
self.reverse_weight)
decoder_time = time.time() - start decoder_time = time.time() - start
#logger.debug(f"decoder time: {decoder_time}") #logger.debug(f"decoder time: {decoder_time}")
@ -152,13 +153,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
return loss, loss_att, loss_ctc return loss, loss_att, loss_ctc
def _calc_att_loss( def _calc_att_loss(self,
self, encoder_out: paddle.Tensor,
encoder_out: paddle.Tensor, encoder_mask: paddle.Tensor,
encoder_mask: paddle.Tensor, ys_pad: paddle.Tensor,
ys_pad: paddle.Tensor, ys_pad_lens: paddle.Tensor,
ys_pad_lens: paddle.Tensor, reverse_weight: float) -> Tuple[paddle.Tensor, float]:
reverse_weight: float) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss. """Calc attention loss.
Args: Args:
@ -188,8 +188,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_loss_att = paddle.to_tensor(0.0) r_loss_att = paddle.to_tensor(0.0)
if reverse_weight > 0.0: if reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - reverse_weight loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
) + r_loss_att * reverse_weight
acc_att = th_accuracy( acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size), decoder_out.view(-1, self.vocab_size),
ys_out_pad, ys_out_pad,
@ -599,8 +598,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}" f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
) )
score = score * (1 - reverse_weight score = score * (1 - reverse_weight) + r_score * reverse_weight
) + r_score * reverse_weight
# add ctc score (which in ln domain) # add ctc score (which in ln domain)
score += hyp[1] * ctc_weight score += hyp[1] * ctc_weight
if score > best_score: if score > best_score:

@ -22,7 +22,6 @@ from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource

Loading…
Cancel
Save