format code

pull/1273/head
Hui Zhang 3 years ago
parent 6f651d762e
commit 3a2db414e6

@ -19,8 +19,8 @@ import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -17,8 +17,8 @@ import functools
import numpy as np import numpy as np
import paddle import paddle
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Export for DeepSpeech2 model.""" """Export for DeepSpeech2 model."""
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Evaluation for DeepSpeech2 model.""" """Evaluation for DeepSpeech2 model."""
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
"""Trainer for DeepSpeech2 model.""" """Trainer for DeepSpeech2 model."""
from paddle import distributed as dist from paddle import distributed as dist
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Export for U2 model.""" """Export for U2 model."""
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -16,8 +16,8 @@ import cProfile
import os import os
from paddle import distributed as dist from paddle import distributed as dist
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -42,6 +42,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class U2Trainer(Trainer): class U2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Export for U2 model.""" """Export for U2 model."""
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments

@ -208,8 +208,7 @@ class U2STTrainer(Trainer):
k.split(',')) == 2 else "" k.split(',')) == 2 else ""
msg += "," msg += ","
msg = msg[:-1] # remove the last "," msg = msg[:-1] # remove the last ","
if (batch_index + 1 if (batch_index + 1) % self.config.log_interval == 0:
) % self.config.log_interval == 0:
logger.info(msg) logger.info(msg)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
@ -260,7 +259,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0, batch_frames_in=0,
batch_frames_out=0, batch_frames_out=0,
batch_frames_inout=0, batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers, n_iter_processes=config.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
@ -281,7 +281,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0, batch_frames_in=0,
batch_frames_out=0, batch_frames_out=0,
batch_frames_inout=0, batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers, n_iter_processes=config.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
@ -290,7 +291,8 @@ class U2STTrainer(Trainer):
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
# test dataset, return raw text # test dataset, return raw text
decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1) decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = BatchDataLoader( self.test_loader = BatchDataLoader(
json_file=config.test_manifest, json_file=config.test_manifest,
train_mode=False, train_mode=False,
@ -305,7 +307,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0, batch_frames_in=0,
batch_frames_out=0, batch_frames_out=0,
batch_frames_inout=0, batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers, n_iter_processes=config.num_workers,
subsampling_factor=1, subsampling_factor=1,
num_encs=1, num_encs=1,

@ -119,6 +119,7 @@ class DeepSpeech2Model(nn.Layer):
before softmax) and a ctc cost layer. before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput :rtype: tuple of LayerOutput
""" """
def __init__(self, def __init__(self,
feat_size, feat_size,
dict_size, dict_size,

@ -243,6 +243,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
before softmax) and a ctc cost layer. before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput :rtype: tuple of LayerOutput
""" """
def __init__( def __init__(
self, self,
feat_size, feat_size,

@ -59,6 +59,7 @@ logger = Log(__name__).getlog()
class U2BaseModel(ASRInterface, nn.Layer): class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model""" """CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self, def __init__(self,
vocab_size: int, vocab_size: int,
encoder: TransformerEncoder, encoder: TransformerEncoder,

@ -51,6 +51,7 @@ logger = Log(__name__).getlog()
class U2STBaseModel(nn.Layer): class U2STBaseModel(nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model""" """CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self, def __init__(self,
vocab_size: int, vocab_size: int,
encoder: TransformerEncoder, encoder: TransformerEncoder,

@ -39,7 +39,6 @@ except ImportError:
except Exception as e: except Exception as e:
logger.info("paddlespeech_ctcdecoders not installed!") logger.info("paddlespeech_ctcdecoders not installed!")
__all__ = ['CTCDecoder'] __all__ = ['CTCDecoder']

Loading…
Cancel
Save