Merge pull request #1275 from Jackwaterveg/tmp

[ASR]use pre-commit
pull/1284/head
Jackwaterveg 3 years ago committed by GitHub
commit f7ffd9917c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
from typing import Optional
import paddle
from paddle import nn
from src_deepspeech2x.models.ds2.rnn import RNNStack
from yacs.config import CfgNode
from paddlespeech.s2t.models.ds2.conv import ConvStack
from paddlespeech.s2t.modules.ctc import CTCDecoder

@ -15,8 +15,6 @@
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
@ -24,7 +22,6 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator

@ -19,8 +19,8 @@ import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -17,8 +17,8 @@ import functools
import numpy as np
import paddle
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for DeepSpeech2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -13,6 +13,7 @@
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -20,7 +20,6 @@ import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.config import get_cfg_defaults
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -13,8 +13,8 @@
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -16,7 +16,6 @@ import os
import time
from collections import defaultdict
from contextlib import nullcontext
from typing import Optional
import jsonlines
import numpy as np
@ -24,7 +23,6 @@ import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -16,8 +16,8 @@ import cProfile
import os
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -18,13 +18,11 @@ import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader

@ -17,13 +17,11 @@ import os
import time
from collections import defaultdict
from contextlib import nullcontext
from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_dict
@ -42,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -18,13 +18,11 @@ import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
@ -208,8 +206,7 @@ class U2STTrainer(Trainer):
k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1] # remove the last ","
if (batch_index + 1
) % self.config.log_interval == 0:
if (batch_index + 1) % self.config.log_interval == 0:
logger.info(msg)
except Exception as e:
logger.error(e)
@ -260,7 +257,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
@ -281,7 +279,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
@ -290,7 +289,8 @@ class U2STTrainer(Trainer):
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1)
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
@ -305,7 +305,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,

@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Optional
import numpy as np
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
from paddlespeech.s2t.frontend.featurizer.speech_featurizer import SpeechFeaturizer

@ -13,11 +13,8 @@
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional
import jsonlines
from paddle.io import Dataset
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.utils.log import Log

@ -12,11 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
from typing import Optional
import paddle
from paddle import nn
from yacs.config import CfgNode
from paddlespeech.s2t.models.ds2.conv import ConvStack
from paddlespeech.s2t.models.ds2.rnn import RNNStack
@ -119,6 +116,7 @@ class DeepSpeech2Model(nn.Layer):
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
def __init__(self,
feat_size,
dict_size,

@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode
from paddlespeech.s2t.models.ds2_online.conv import Conv2dSubsampling4Online
from paddlespeech.s2t.modules.ctc import CTCDecoder
@ -243,6 +240,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
def __init__(
self,
feat_size,

@ -26,7 +26,6 @@ from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
@ -59,6 +58,7 @@ logger = Log(__name__).getlog()
class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,

@ -24,7 +24,6 @@ from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
@ -51,6 +50,7 @@ logger = Log(__name__).getlog()
class U2STBaseModel(nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,

Loading…
Cancel
Save