Merge pull request #2334 from Zth9730/fix_multigpu_train

[s2t] fix asr_engine.py
pull/2347/head
Hui Zhang 3 years ago committed by GitHub
commit 58ab7e8d10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,10 +21,10 @@ import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler:
## conformer
# cache for conformer online
self.att_cache = paddle.zeros([0,0,0,0])
self.cnn_cache = paddle.zeros([0,0,0,0])
self.att_cache = paddle.zeros([0, 0, 0, 0])
self.cnn_cache = paddle.zeros([0, 0, 0, 0])
self.encoder_out = None
# conformer decoding state
@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler:
# cur chunk
chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk
(y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.att_cache, self.cnn_cache)
(y, self.att_cache,
self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size, self.att_cache,
self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y)
# update the global offset, in decoding frame unit

Loading…
Cancel
Save