pull/2502/head
Hui Zhang 2 years ago
parent e86337a423
commit 1f4f98b171

@ -18,6 +18,7 @@ from pathlib import Path
import paddle
import soundfile
from paddleslim import PTQ
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
@ -26,7 +27,6 @@ from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddleslim import PTQ
logger = Log(__name__).getlog()
@ -90,7 +90,7 @@ class U2Infer():
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=decode_config.reverse_weight)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
@ -161,7 +161,11 @@ class U2Infer():
# jit save
logger.info(f"export save: {self.args.export_path}")
config = {'is_static': True, 'combine_params':True, 'skip_forward':True}
config = {
'is_static': True,
'combine_params': True,
'skip_forward': True
}
self.ptq.save_quantized_model(self.model, self.args.export_path)
# paddle.jit.save(
# self.model,
@ -170,7 +174,6 @@ class U2Infer():
# skip_forward=True)
def check(audio_file):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
@ -201,7 +204,10 @@ if __name__ == "__main__":
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--export_path", type=str, default='export', help="path of the input audio file")
"--export_path",
type=str,
default='export',
help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)

@ -131,7 +131,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
if self.ctc_weight != 1.0:
start = time.time()
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths, self.reverse_weight)
text, text_lengths,
self.reverse_weight)
decoder_time = time.time() - start
#logger.debug(f"decoder time: {decoder_time}")
@ -152,8 +153,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
return loss, loss_att, loss_ctc
def _calc_att_loss(
self,
def _calc_att_loss(self,
encoder_out: paddle.Tensor,
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
@ -188,8 +188,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_loss_att = paddle.to_tensor(0.0)
if reverse_weight > 0.0:
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - reverse_weight
) + r_loss_att * reverse_weight
loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
@ -599,8 +598,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
)
score = score * (1 - reverse_weight
) + r_score * reverse_weight
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:

@ -22,7 +22,6 @@ from numpy import float32
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource

Loading…
Cancel
Save