pull/930/head
Hui Zhang 3 years ago
parent e76d51fda0
commit 12ea02fc48

@ -22,9 +22,9 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["NonePositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"] __all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"]
class NonePositionalEncoding(nn.Layer): class NoPositionalEncoding(nn.Layer):
def __init__(self, def __init__(self,
d_model: int, d_model: int,
dropout_rate: float, dropout_rate: float,

@ -26,7 +26,7 @@ from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.embedding import NonePositionalEncoding from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer from deepspeech.modules.encoder_layer import ConformerEncoderLayer
from deepspeech.modules.encoder_layer import TransformerEncoderLayer from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask from deepspeech.modules.mask import add_optional_chunk_mask
@ -56,7 +56,7 @@ class BaseEncoder(nn.Layer):
positional_dropout_rate: float=0.1, positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0, attention_dropout_rate: float=0.0,
input_layer: str="conv2d", input_layer: str="conv2d",
pos_enc_layer_type: Optional[str, None]="abs_pos", pos_enc_layer_type: str="abs_pos",
normalize_before: bool=True, normalize_before: bool=True,
concat_after: bool=False, concat_after: bool=False,
static_chunk_size: int=0, static_chunk_size: int=0,
@ -77,8 +77,8 @@ class BaseEncoder(nn.Layer):
positional encoding positional encoding
input_layer (str): input layer type. input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8] optional [linear, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str, or None): Encoder positional encoding layer type. pos_enc_layer_type (str): Encoder positional encoding layer type.
opitonal [abs_pos, scaled_abs_pos, rel_pos, None] opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
normalize_before (bool): normalize_before (bool):
True: use layer_norm before each sub-block of a layer. True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer. False: use layer_norm after each sub-block of a layer.
@ -103,8 +103,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos": elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type is None: elif pos_enc_layer_type is "no_pos":
pos_enc_class = NonePositionalEncoding pos_enc_class = NoPositionalEncoding
else: else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

Loading…
Cancel
Save