From 1dfca4ef736493a99e2ac35f4d985b20472aa197 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 31 Aug 2022 02:43:54 +0000 Subject: [PATCH] fix multigpu training --- .../server/engine/asr/online/python/asr_engine.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 4df38f09..96d4823e 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -21,10 +21,10 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource -from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.tensor_utils import add_sos_eos @@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler: ## conformer # cache for conformer online - self.att_cache = paddle.zeros([0,0,0,0]) - self.cnn_cache = paddle.zeros([0,0,0,0]) + self.att_cache = paddle.zeros([0, 0, 0, 0]) + self.cnn_cache = paddle.zeros([0, 0, 0, 0]) self.encoder_out = None # conformer decoding state @@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler: # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] # forward chunk - (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk( - chunk_xs, self.offset, required_cache_size, - self.att_cache, self.cnn_cache) + (y, self.att_cache, + self.cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, self.att_cache, + self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool)) outputs.append(y) # update the global offset, in decoding frame unit