From 226cfac0bd0b50cb482a20cc8f7b35e64a341016 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Wed, 15 Jan 2025 05:55:21 +0800 Subject: [PATCH] fix setup.py --- paddlespeech/__init__.py | 4 ---- .../s2t/modules/conformer_convolution.py | 4 ++-- paddlespeech/s2t/modules/ctc.py | 4 ++-- paddlespeech/s2t/modules/decoder.py | 8 +++----- paddlespeech/s2t/modules/encoder.py | 11 +++++------ paddlespeech/s2t/training/scheduler.py | 4 ++-- .../t2s/models/diffsinger/diffsinger.py | 6 +++--- .../t2s/models/diffsinger/fastspeech2midi.py | 6 +++--- .../t2s/models/fastspeech2/fastspeech2.py | 15 ++++++++------- paddlespeech/t2s/models/jets/generator.py | 2 +- paddlespeech/t2s/models/jets/jets.py | 7 +++---- paddlespeech/t2s/models/tacotron2/tacotron2.py | 4 ++-- .../models/transformer_tts/transformer_tts.py | 5 +++-- paddlespeech/t2s/models/vits/vits.py | 18 ++++++++++-------- .../adversarial_loss/speaker_classifier.py | 4 ++-- paddlespeech/t2s/modules/losses.py | 4 ++-- paddlespeech/t2s/modules/nets_utils.py | 5 ++--- .../modules/predictor/variance_predictor.py | 6 +++--- paddlespeech/t2s/modules/style_encoder.py | 8 ++++---- setup.py | 6 +++--- 20 files changed, 63 insertions(+), 68 deletions(-) diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 969d189f5..6c7e75c1f 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -13,7 +13,3 @@ # limitations under the License. import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) - -__version__ = '0.0.0' - -__commit__ = '9cf8c1985a98bb380c183116123672976bdfe5c9' diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 7a0c72f3b..4a2b449d1 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -19,7 +19,7 @@ from typing import Tuple import paddle from paddle import nn from paddle.nn import initializer as I -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.s2t.modules.align import BatchNorm1D from paddlespeech.s2t.modules.align import Conv1D @@ -34,6 +34,7 @@ __all__ = ['ConvolutionModule'] class ConvolutionModule(nn.Layer): """ConvolutionModule in Conformer model.""" + @typechecked def __init__(self, channels: int, kernel_size: int=15, @@ -52,7 +53,6 @@ class ConvolutionModule(nn.Layer): causal (bool): Whether use causal convolution or not bias (bool): Whether Conv with bias or not """ - assert check_argument_types() super().__init__() self.bias = bias self.channels = channels diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index e0c01ab46..9309a1e0e 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -17,7 +17,7 @@ from typing import Union import paddle from paddle import nn from paddle.nn import functional as F -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.loss import CTCLoss @@ -48,6 +48,7 @@ __all__ = ['CTCDecoder'] class CTCDecoderBase(nn.Layer): + @typechecked def __init__(self, odim, enc_n_units, @@ -66,7 +67,6 @@ class CTCDecoderBase(nn.Layer): batch_average (bool): do batch dim wise average. grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None. """ - assert check_argument_types() super().__init__() self.blank_id = blank_id diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 1881a865c..6a65b2cee 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -21,7 +21,7 @@ from typing import Tuple import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.s2t.decoders.scorers.scorer_interface import BatchScorerInterface from paddlespeech.s2t.modules.align import Embedding @@ -61,6 +61,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): False: x -> x + att(x) """ + @typechecked def __init__(self, vocab_size: int, encoder_output_size: int, @@ -77,8 +78,6 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): concat_after: bool=False, max_len: int=5000): - assert check_argument_types() - nn.Layer.__init__(self) self.selfattention_layer_type = 'selfattn' attention_dim = encoder_output_size @@ -276,6 +275,7 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer): False: x -> x + att(x) """ + @typechecked def __init__(self, vocab_size: int, encoder_output_size: int, @@ -293,8 +293,6 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer): concat_after: bool=False, max_len: int=5000): - assert check_argument_types() - nn.Layer.__init__(self) self.left_decoder = TransformerDecoder( vocab_size, encoder_output_size, attention_heads, linear_units, diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 27d7ffbd7..841145759 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -21,7 +21,7 @@ from typing import Union import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.s2t.modules.activation import get_activation from paddlespeech.s2t.modules.align import LayerNorm @@ -58,6 +58,7 @@ __all__ = [ class BaseEncoder(nn.Layer): + @typechecked def __init__(self, input_size: int, output_size: int=256, @@ -108,7 +109,6 @@ class BaseEncoder(nn.Layer): use_dynamic_left_chunk (bool): whether use dynamic left chunk in dynamic chunk training """ - assert check_argument_types() super().__init__() self._output_size = output_size @@ -349,6 +349,7 @@ class BaseEncoder(nn.Layer): class TransformerEncoder(BaseEncoder): """Transformer encoder module.""" + @typechecked def __init__( self, input_size: int, @@ -370,7 +371,6 @@ class TransformerEncoder(BaseEncoder): """ Construct TransformerEncoder See Encoder for the meaning of each parameter. """ - assert check_argument_types() super().__init__(input_size, output_size, attention_heads, linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, @@ -424,6 +424,7 @@ class TransformerEncoder(BaseEncoder): class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" + @typechecked def __init__(self, input_size: int, output_size: int=256, @@ -466,8 +467,6 @@ class ConformerEncoder(BaseEncoder): causal (bool): whether to use causal convolution or not. cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm'] """ - assert check_argument_types() - super().__init__(input_size, output_size, attention_heads, linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, @@ -519,6 +518,7 @@ class ConformerEncoder(BaseEncoder): class SqueezeformerEncoder(nn.Layer): + @typechecked def __init__(self, input_size: int, encoder_dim: int=256, @@ -572,7 +572,6 @@ class SqueezeformerEncoder(nn.Layer): init_weights (bool): Whether to initialize weights. causal (bool): whether to use causal convolution or not. """ - assert check_argument_types() super().__init__() self.global_cmvn = global_cmvn self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \ diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py index a5e7a08f1..994b6f734 100644 --- a/paddlespeech/s2t/training/scheduler.py +++ b/paddlespeech/s2t/training/scheduler.py @@ -19,7 +19,7 @@ from typing import Union import paddle from paddle.optimizer.lr import LRScheduler -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import instance_class @@ -57,13 +57,13 @@ class WarmupLR(LRScheduler): Note that the maximum lr equals to optimizer.lr in this scheduler. """ + @typechecked def __init__(self, warmup_steps: Union[int, float]=25000, learning_rate=1.0, last_epoch=-1, verbose=False, **kwargs): - assert check_argument_types() self.warmup_steps = warmup_steps super().__init__(learning_rate, last_epoch, verbose) diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py index 990cfc56a..e489ff724 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py @@ -20,7 +20,7 @@ from typing import Tuple import numpy as np import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI from paddlespeech.t2s.modules.diffnet import DiffNet @@ -40,6 +40,7 @@ class DiffSinger(nn.Layer): """ + @typechecked def __init__( self, # min and max spec for stretching before diffusion @@ -157,7 +158,6 @@ class DiffSinger(nn.Layer): denoiser_params (Dict[str, Any]): Parameter dict for dinoiser module. diffusion_params (Dict[str, Any]): Parameter dict for diffusion module. """ - assert check_argument_types() super().__init__() self.fs2 = FastSpeech2MIDI( idim=idim, @@ -336,6 +336,7 @@ class DiffSingerInference(nn.Layer): class DiffusionLoss(nn.Layer): """Loss function module for Diffusion module on DiffSinger.""" + @typechecked def __init__(self, use_masking: bool=True, use_weighted_masking: bool=False): """Initialize feed-forward Transformer loss module. @@ -345,7 +346,6 @@ class DiffusionLoss(nn.Layer): use_weighted_masking (bool): Whether to weighted masking in loss calculation. """ - assert check_argument_types() super().__init__() assert (use_masking != use_weighted_masking) or not use_masking diff --git a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py index cce88d8a0..3aff4c4e6 100644 --- a/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py +++ b/paddlespeech/t2s/models/diffsinger/fastspeech2midi.py @@ -19,7 +19,7 @@ from typing import Tuple import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss @@ -33,6 +33,7 @@ class FastSpeech2MIDI(FastSpeech2): """The Fastspeech2 module of DiffSinger. """ + @typechecked def __init__( self, # fastspeech2 network structure related @@ -57,7 +58,6 @@ class FastSpeech2MIDI(FastSpeech2): is_slur_ids will be provided as the input """ - assert check_argument_types() super().__init__(idim=idim, odim=odim, **fastspeech2_params) self.use_energy_pred = use_energy_pred self.use_postnet = use_postnet @@ -495,6 +495,7 @@ class FastSpeech2MIDI(FastSpeech2): class FastSpeech2MIDILoss(FastSpeech2Loss): """Loss function module for DiffSinger.""" + @typechecked def __init__(self, use_masking: bool=True, use_weighted_masking: bool=False): """Initialize feed-forward Transformer loss module. @@ -504,7 +505,6 @@ class FastSpeech2MIDILoss(FastSpeech2Loss): use_weighted_masking (bool): Whether to weighted masking in loss calculation. """ - assert check_argument_types() super().__init__(use_masking, use_weighted_masking) def forward( diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 91bfc540a..6fb65132d 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -15,6 +15,7 @@ """Fastspeech2 related modules for paddle""" from typing import Dict from typing import List +from typing import Optional from typing import Sequence from typing import Tuple from typing import Union @@ -23,7 +24,7 @@ import numpy as np import paddle import paddle.nn.functional as F from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier @@ -60,6 +61,7 @@ class FastSpeech2(nn.Layer): """ + @typechecked def __init__( self, # network structure related @@ -131,12 +133,12 @@ class FastSpeech2(nn.Layer): pitch_embed_dropout: float=0.5, stop_gradient_from_pitch_predictor: bool=False, # spk emb - spk_num: int=None, - spk_embed_dim: int=None, + spk_num: Optional[int]=None, + spk_embed_dim: Optional[int]=None, spk_embed_integration_type: str="add", # tone emb - tone_num: int=None, - tone_embed_dim: int=None, + tone_num: Optional[int]=None, + tone_embed_dim: Optional[int]=None, tone_embed_integration_type: str="add", # training related init_type: str="xavier_uniform", @@ -282,7 +284,6 @@ class FastSpeech2(nn.Layer): The hidden layer dim of speaker classifier """ - assert check_argument_types() super().__init__() # store hyperparameters @@ -1070,6 +1071,7 @@ class StyleFastSpeech2Inference(FastSpeech2Inference): class FastSpeech2Loss(nn.Layer): """Loss function module for FastSpeech2.""" + @typechecked def __init__(self, use_masking: bool=True, use_weighted_masking: bool=False): """Initialize feed-forward Transformer loss module. @@ -1079,7 +1081,6 @@ class FastSpeech2Loss(nn.Layer): use_weighted_masking (bool): Whether to weighted masking in loss calculation. """ - assert check_argument_types() super().__init__() assert (use_masking != use_weighted_masking) or not use_masking diff --git a/paddlespeech/t2s/models/jets/generator.py b/paddlespeech/t2s/models/jets/generator.py index 1b8e0ce6e..d2bd4102d 100644 --- a/paddlespeech/t2s/models/jets/generator.py +++ b/paddlespeech/t2s/models/jets/generator.py @@ -28,7 +28,7 @@ from typing import Tuple import numpy as np import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.models.hifigan import HiFiGANGenerator from paddlespeech.t2s.models.jets.alignments import AlignmentModule diff --git a/paddlespeech/t2s/models/jets/jets.py b/paddlespeech/t2s/models/jets/jets.py index 4346c65b4..9c02da6b5 100644 --- a/paddlespeech/t2s/models/jets/jets.py +++ b/paddlespeech/t2s/models/jets/jets.py @@ -24,7 +24,7 @@ from typing import Optional import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator @@ -64,6 +64,7 @@ class JETS(nn.Layer): Text-to-Speech`: https://arxiv.org/abs/2203.16852v1 """ + @typechecked def __init__( self, # generator related @@ -225,7 +226,6 @@ class JETS(nn.Layer): cache_generator_outputs (bool): Whether to cache generator outputs. """ - assert check_argument_types() super().__init__() # define modules @@ -279,8 +279,7 @@ class JETS(nn.Layer): lids: Optional[paddle.Tensor]=None, forward_generator: bool=True, use_alignment_module: bool=False, - **kwargs, - ) -> Dict[str, Any]: + **kwargs, ) -> Dict[str, Any]: """Perform generator forward. Args: text (Tensor): diff --git a/paddlespeech/t2s/models/tacotron2/tacotron2.py b/paddlespeech/t2s/models/tacotron2/tacotron2.py index 25b5c932a..404d1fa1c 100644 --- a/paddlespeech/t2s/models/tacotron2/tacotron2.py +++ b/paddlespeech/t2s/models/tacotron2/tacotron2.py @@ -21,7 +21,7 @@ from typing import Tuple import paddle import paddle.nn.functional as F from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.nets_utils import make_pad_mask @@ -44,6 +44,7 @@ class Tacotron2(nn.Layer): """ + @typechecked def __init__( self, # network structure related @@ -145,7 +146,6 @@ class Tacotron2(nn.Layer): zoneout_rate (float): Zoneout rate. """ - assert check_argument_types() super().__init__() # store hyperparameters diff --git a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py index 355fceb16..80d8a60da 100644 --- a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py +++ b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py @@ -21,7 +21,7 @@ import numpy import paddle import paddle.nn.functional as F from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask @@ -169,6 +169,7 @@ class TransformerTTS(nn.Layer): Number of layers to apply guided attention loss. """ + @typechecked def __init__( self, # network structure related @@ -227,7 +228,7 @@ class TransformerTTS(nn.Layer): num_heads_applied_guided_attn: int=2, num_layers_applied_guided_attn: int=2, ): """Initialize Transformer module.""" - assert check_argument_types() + super().__init__() # store hyperparameters diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py index 7013e06c0..e92e78676 100644 --- a/paddlespeech/t2s/models/vits/vits.py +++ b/paddlespeech/t2s/models/vits/vits.py @@ -20,7 +20,7 @@ from typing import Optional import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator @@ -60,6 +60,7 @@ class VITS(nn.Layer): Text-to-Speech`: https://arxiv.org/abs/2006.04558 """ + @typechecked def __init__( self, # generator related @@ -181,7 +182,6 @@ class VITS(nn.Layer): cache_generator_outputs (bool): Whether to cache generator outputs. """ - assert check_argument_types() super().__init__() # define modules @@ -504,8 +504,9 @@ class VITS(nn.Layer): def reset_parameters(self): def _reset_parameters(module): - if isinstance(module, - (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)): + if isinstance( + module, + (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)): kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) @@ -513,8 +514,9 @@ class VITS(nn.Layer): bound = 1 / math.sqrt(fan_in) uniform_(module.bias, -bound, bound) - if isinstance(module, - (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): + if isinstance( + module, + (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): ones_(module.weight) zeros_(module.bias) @@ -533,13 +535,13 @@ class VITS(nn.Layer): self.apply(_reset_parameters) + class VITSInference(nn.Layer): def __init__(self, model): super().__init__() self.acoustic_model = model def forward(self, text, sids=None): - out = self.acoustic_model.inference( - text, sids=sids) + out = self.acoustic_model.inference(text, sids=sids) wav = out['wav'] return wav diff --git a/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py index d731b2d27..663a76ffe 100644 --- a/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py +++ b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py @@ -14,16 +14,16 @@ # Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning) import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked class SpeakerClassifier(nn.Layer): + @typechecked def __init__( self, idim: int, hidden_sc_dim: int, spk_num: int, ): - assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index e675dcab7..f819352d6 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -21,7 +21,7 @@ from paddle import nn from paddle.nn import functional as F from scipy import signal from scipy.stats import betabinom -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.predictor.duration_predictor import ( @@ -1137,6 +1137,7 @@ class MLMLoss(nn.Layer): class VarianceLoss(nn.Layer): + @typechecked def __init__(self, use_masking: bool=True, use_weighted_masking: bool=False): """Initialize JETS variance loss module. @@ -1147,7 +1148,6 @@ class VarianceLoss(nn.Layer): calculation. """ - assert check_argument_types() super().__init__() assert (use_masking != use_weighted_masking) or not use_masking diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 4c86d74f5..7a3f52fe6 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -18,7 +18,7 @@ from typing import Tuple import numpy as np import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out from paddlespeech.utils.initialize import kaiming_uniform_ @@ -301,6 +301,7 @@ def make_non_pad_mask(lengths, xs=None, length_dim=-1): return paddle.logical_not(make_pad_mask(lengths, xs, length_dim)) +@typechecked def initialize(model: nn.Layer, init: str): """Initialize weights of a neural network module. @@ -314,8 +315,6 @@ def initialize(model: nn.Layer, init: str): init (str): Method of initialization. """ - assert check_argument_types() - if init == "xavier_uniform": nn.initializer.set_global_initializer(nn.initializer.XavierUniform(), nn.initializer.Constant()) diff --git a/paddlespeech/t2s/modules/predictor/variance_predictor.py b/paddlespeech/t2s/modules/predictor/variance_predictor.py index 197f73595..4b79b3913 100644 --- a/paddlespeech/t2s/modules/predictor/variance_predictor.py +++ b/paddlespeech/t2s/modules/predictor/variance_predictor.py @@ -15,7 +15,7 @@ """Variance predictor related modules.""" import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.layer_norm import LayerNorm from paddlespeech.t2s.modules.masked_fill import masked_fill @@ -32,6 +32,7 @@ class VariancePredictor(nn.Layer): """ + @typechecked def __init__( self, idim: int, @@ -54,7 +55,6 @@ class VariancePredictor(nn.Layer): dropout_rate (float, optional): Dropout rate. """ - assert check_argument_types() super().__init__() self.conv = nn.LayerList() for idx in range(n_layers): @@ -96,7 +96,7 @@ class VariancePredictor(nn.Layer): xs = f(xs) # (B, Tmax, 1) xs = self.linear(xs.transpose([0, 2, 1])) - + if x_masks is not None: xs = masked_fill(xs, x_masks, 0.0) return xs diff --git a/paddlespeech/t2s/modules/style_encoder.py b/paddlespeech/t2s/modules/style_encoder.py index b558e7693..ad86a5449 100644 --- a/paddlespeech/t2s/modules/style_encoder.py +++ b/paddlespeech/t2s/modules/style_encoder.py @@ -17,7 +17,7 @@ from typing import Sequence import paddle from paddle import nn -from typeguard import check_argument_types +from typeguard import typechecked from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention as BaseMultiHeadedAttention @@ -58,6 +58,7 @@ class StyleEncoder(nn.Layer): """ + @typechecked def __init__( self, idim: int=80, @@ -71,7 +72,6 @@ class StyleEncoder(nn.Layer): gru_layers: int=1, gru_units: int=128, ): """Initilize global style encoder module.""" - assert check_argument_types() super().__init__() self.ref_enc = ReferenceEncoder( @@ -132,6 +132,7 @@ class ReferenceEncoder(nn.Layer): """ + @typechecked def __init__( self, idim=80, @@ -142,7 +143,6 @@ class ReferenceEncoder(nn.Layer): gru_layers: int=1, gru_units: int=128, ): """Initilize reference encoder module.""" - assert check_argument_types() super().__init__() # check hyperparameters are valid @@ -232,6 +232,7 @@ class StyleTokenLayer(nn.Layer): """ + @typechecked def __init__( self, ref_embed_dim: int=128, @@ -240,7 +241,6 @@ class StyleTokenLayer(nn.Layer): gst_heads: int=4, dropout_rate: float=0.0, ): """Initilize style token layer module.""" - assert check_argument_types() super().__init__() gst_embs = paddle.randn(shape=[gst_tokens, gst_token_dim // gst_heads]) diff --git a/setup.py b/setup.py index b15dedbac..c9e18f685 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ base = [ "matplotlib<=3.8.4", "nara_wpe", "onnxruntime>=1.11.0", - "opencc==1.1.6", + "opencc", "opencc-python-reimplemented", "pandas", "paddleaudio>=1.1.0", @@ -69,8 +69,8 @@ base = [ "soundfile", "textgrid", "timer", - "ToJyutping==0.2.1", - "typeguard==2.13.3", + "ToJyutping", + "typeguard", "webrtcvad", "yacs~=0.1.8", "zhon",