You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
580 lines
23 KiB
580 lines
23 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
|
|
import paddle
|
|
import yaml
|
|
from paddle import nn
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.t2s.modules.activation import get_activation
|
|
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
|
|
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
|
|
from paddlespeech.t2s.modules.layer_norm import LayerNorm
|
|
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
|
from paddlespeech.t2s.modules.nets_utils import initialize
|
|
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
|
|
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
|
|
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
|
|
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
|
|
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
|
|
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
|
|
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
|
|
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
|
|
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
|
|
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
|
|
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
|
from paddlespeech.t2s.modules.transformer.repeat import repeat
|
|
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
|
|
|
|
|
|
# MLM -> Mask Language Model
|
|
class mySequential(nn.Sequential):
|
|
def forward(self, *inputs):
|
|
for module in self._sub_layers.values():
|
|
if type(inputs) == tuple:
|
|
inputs = module(*inputs)
|
|
else:
|
|
inputs = module(inputs)
|
|
return inputs
|
|
|
|
|
|
class MaskInputLayer(nn.Layer):
|
|
def __init__(self, out_features: int) -> None:
|
|
super().__init__()
|
|
self.mask_feature = paddle.create_parameter(
|
|
shape=(1, 1, out_features),
|
|
dtype=paddle.float32,
|
|
default_initializer=paddle.nn.initializer.Assign(
|
|
paddle.normal(shape=(1, 1, out_features))))
|
|
|
|
def forward(self, input: paddle.Tensor,
|
|
masked_pos: paddle.Tensor=None) -> paddle.Tensor:
|
|
masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
|
|
masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
|
|
paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
|
|
return masked_input
|
|
|
|
|
|
class MLMEncoder(nn.Layer):
|
|
"""Conformer encoder module.
|
|
|
|
Args:
|
|
idim (int): Input dimension.
|
|
attention_dim (int): Dimension of attention.
|
|
attention_heads (int): The number of heads of multi head attention.
|
|
linear_units (int): The number of units of position-wise feed forward.
|
|
num_blocks (int): The number of decoder blocks.
|
|
dropout_rate (float): Dropout rate.
|
|
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
|
attention_dropout_rate (float): Dropout rate in attention.
|
|
input_layer (Union[str, paddle.nn.Layer]): Input layer type.
|
|
normalize_before (bool): Whether to use layer_norm before the first block.
|
|
concat_after (bool): Whether to concat attention layer's input and output.
|
|
if True, additional linear will be applied.
|
|
i.e. x -> x + linear(concat(x, att(x)))
|
|
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
|
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
|
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
|
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
|
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
|
selfattention_layer_type (str): Encoder attention layer type.
|
|
activation_type (str): Encoder activation function type.
|
|
use_cnn_module (bool): Whether to use convolution module.
|
|
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
|
cnn_module_kernel (int): Kernerl size of convolution module.
|
|
padding_idx (int): Padding idx for input_layer=embed.
|
|
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
|
|
|
|
"""
|
|
|
|
def __init__(self,
|
|
idim: int,
|
|
vocab_size: int=0,
|
|
pre_speech_layer: int=0,
|
|
attention_dim: int=256,
|
|
attention_heads: int=4,
|
|
linear_units: int=2048,
|
|
num_blocks: int=6,
|
|
dropout_rate: float=0.1,
|
|
positional_dropout_rate: float=0.1,
|
|
attention_dropout_rate: float=0.0,
|
|
input_layer: str="conv2d",
|
|
normalize_before: bool=True,
|
|
concat_after: bool=False,
|
|
positionwise_layer_type: str="linear",
|
|
positionwise_conv_kernel_size: int=1,
|
|
macaron_style: bool=False,
|
|
pos_enc_layer_type: str="abs_pos",
|
|
selfattention_layer_type: str="selfattn",
|
|
activation_type: str="swish",
|
|
use_cnn_module: bool=False,
|
|
zero_triu: bool=False,
|
|
cnn_module_kernel: int=31,
|
|
padding_idx: int=-1,
|
|
stochastic_depth_rate: float=0.0,
|
|
text_masking: bool=False):
|
|
"""Construct an Encoder object."""
|
|
super().__init__()
|
|
self._output_size = attention_dim
|
|
self.text_masking = text_masking
|
|
if self.text_masking:
|
|
self.text_masking_layer = MaskInputLayer(attention_dim)
|
|
activation = get_activation(activation_type)
|
|
if pos_enc_layer_type == "abs_pos":
|
|
pos_enc_class = PositionalEncoding
|
|
elif pos_enc_layer_type == "scaled_abs_pos":
|
|
pos_enc_class = ScaledPositionalEncoding
|
|
elif pos_enc_layer_type == "rel_pos":
|
|
assert selfattention_layer_type == "rel_selfattn"
|
|
pos_enc_class = RelPositionalEncoding
|
|
elif pos_enc_layer_type == "legacy_rel_pos":
|
|
pos_enc_class = LegacyRelPositionalEncoding
|
|
assert selfattention_layer_type == "legacy_rel_selfattn"
|
|
else:
|
|
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
|
|
|
self.conv_subsampling_factor = 1
|
|
if input_layer == "linear":
|
|
self.embed = nn.Sequential(
|
|
nn.Linear(idim, attention_dim),
|
|
nn.LayerNorm(attention_dim),
|
|
nn.Dropout(dropout_rate),
|
|
nn.ReLU(),
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
elif input_layer == "conv2d":
|
|
self.embed = Conv2dSubsampling(
|
|
idim,
|
|
attention_dim,
|
|
dropout_rate,
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
self.conv_subsampling_factor = 4
|
|
elif input_layer == "embed":
|
|
self.embed = nn.Sequential(
|
|
nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
elif input_layer == "mlm":
|
|
self.segment_emb = None
|
|
self.speech_embed = mySequential(
|
|
MaskInputLayer(idim),
|
|
nn.Linear(idim, attention_dim),
|
|
nn.LayerNorm(attention_dim),
|
|
nn.ReLU(),
|
|
pos_enc_class(attention_dim, positional_dropout_rate))
|
|
self.text_embed = nn.Sequential(
|
|
nn.Embedding(
|
|
vocab_size, attention_dim, padding_idx=padding_idx),
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
elif input_layer == "sega_mlm":
|
|
self.segment_emb = nn.Embedding(
|
|
500, attention_dim, padding_idx=padding_idx)
|
|
self.speech_embed = mySequential(
|
|
MaskInputLayer(idim),
|
|
nn.Linear(idim, attention_dim),
|
|
nn.LayerNorm(attention_dim),
|
|
nn.ReLU(),
|
|
pos_enc_class(attention_dim, positional_dropout_rate))
|
|
self.text_embed = nn.Sequential(
|
|
nn.Embedding(
|
|
vocab_size, attention_dim, padding_idx=padding_idx),
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
elif isinstance(input_layer, nn.Layer):
|
|
self.embed = nn.Sequential(
|
|
input_layer,
|
|
pos_enc_class(attention_dim, positional_dropout_rate), )
|
|
elif input_layer is None:
|
|
self.embed = nn.Sequential(
|
|
pos_enc_class(attention_dim, positional_dropout_rate))
|
|
else:
|
|
raise ValueError("unknown input_layer: " + input_layer)
|
|
self.normalize_before = normalize_before
|
|
|
|
# self-attention module definition
|
|
if selfattention_layer_type == "selfattn":
|
|
encoder_selfattn_layer = MultiHeadedAttention
|
|
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
|
attention_dropout_rate, )
|
|
elif selfattention_layer_type == "legacy_rel_selfattn":
|
|
assert pos_enc_layer_type == "legacy_rel_pos"
|
|
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
|
|
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
|
attention_dropout_rate, )
|
|
elif selfattention_layer_type == "rel_selfattn":
|
|
assert pos_enc_layer_type == "rel_pos"
|
|
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
|
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
|
attention_dropout_rate, zero_triu, )
|
|
else:
|
|
raise ValueError("unknown encoder_attn_layer: " +
|
|
selfattention_layer_type)
|
|
|
|
# feed-forward module definition
|
|
if positionwise_layer_type == "linear":
|
|
positionwise_layer = PositionwiseFeedForward
|
|
positionwise_layer_args = (attention_dim, linear_units,
|
|
dropout_rate, activation, )
|
|
elif positionwise_layer_type == "conv1d":
|
|
positionwise_layer = MultiLayeredConv1d
|
|
positionwise_layer_args = (attention_dim, linear_units,
|
|
positionwise_conv_kernel_size,
|
|
dropout_rate, )
|
|
elif positionwise_layer_type == "conv1d-linear":
|
|
positionwise_layer = Conv1dLinear
|
|
positionwise_layer_args = (attention_dim, linear_units,
|
|
positionwise_conv_kernel_size,
|
|
dropout_rate, )
|
|
else:
|
|
raise NotImplementedError("Support only linear or conv1d.")
|
|
|
|
# convolution module definition
|
|
convolution_layer = ConvolutionModule
|
|
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
|
|
|
self.encoders = repeat(
|
|
num_blocks,
|
|
lambda lnum: EncoderLayer(
|
|
attention_dim,
|
|
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
|
positionwise_layer(*positionwise_layer_args),
|
|
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
|
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
|
dropout_rate,
|
|
normalize_before,
|
|
concat_after,
|
|
stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
|
|
self.pre_speech_layer = pre_speech_layer
|
|
self.pre_speech_encoders = repeat(
|
|
self.pre_speech_layer,
|
|
lambda lnum: EncoderLayer(
|
|
attention_dim,
|
|
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
|
positionwise_layer(*positionwise_layer_args),
|
|
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
|
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
|
dropout_rate,
|
|
normalize_before,
|
|
concat_after,
|
|
stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
|
|
)
|
|
if self.normalize_before:
|
|
self.after_norm = LayerNorm(attention_dim)
|
|
|
|
def forward(self,
|
|
speech: paddle.Tensor,
|
|
text: paddle.Tensor,
|
|
masked_pos: paddle.Tensor,
|
|
speech_mask: paddle.Tensor=None,
|
|
text_mask: paddle.Tensor=None,
|
|
speech_seg_pos: paddle.Tensor=None,
|
|
text_seg_pos: paddle.Tensor=None):
|
|
"""Encode input sequence.
|
|
|
|
"""
|
|
if masked_pos is not None:
|
|
speech = self.speech_embed(speech, masked_pos)
|
|
else:
|
|
speech = self.speech_embed(speech)
|
|
if text is not None:
|
|
text = self.text_embed(text)
|
|
if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
|
|
speech_seg_emb = self.segment_emb(speech_seg_pos)
|
|
text_seg_emb = self.segment_emb(text_seg_pos)
|
|
text = (text[0] + text_seg_emb, text[1])
|
|
speech = (speech[0] + speech_seg_emb, speech[1])
|
|
if self.pre_speech_encoders:
|
|
speech, _ = self.pre_speech_encoders(speech, speech_mask)
|
|
|
|
if text is not None:
|
|
xs = paddle.concat([speech[0], text[0]], axis=1)
|
|
xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
|
|
masks = paddle.concat([speech_mask, text_mask], axis=-1)
|
|
else:
|
|
xs = speech[0]
|
|
xs_pos_emb = speech[1]
|
|
masks = speech_mask
|
|
|
|
xs, masks = self.encoders((xs, xs_pos_emb), masks)
|
|
|
|
if isinstance(xs, tuple):
|
|
xs = xs[0]
|
|
if self.normalize_before:
|
|
xs = self.after_norm(xs)
|
|
|
|
return xs, masks
|
|
|
|
|
|
class MLMDecoder(MLMEncoder):
|
|
def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
|
|
"""Encode input sequence.
|
|
|
|
Args:
|
|
xs (paddle.Tensor): Input tensor (#batch, time, idim).
|
|
masks (paddle.Tensor): Mask tensor (#batch, time).
|
|
|
|
Returns:
|
|
paddle.Tensor: Output tensor (#batch, time, attention_dim).
|
|
paddle.Tensor: Mask tensor (#batch, time).
|
|
|
|
"""
|
|
xs = self.embed(xs)
|
|
xs, masks = self.encoders(xs, masks)
|
|
|
|
if isinstance(xs, tuple):
|
|
xs = xs[0]
|
|
if self.normalize_before:
|
|
xs = self.after_norm(xs)
|
|
|
|
return xs, masks
|
|
|
|
|
|
# encoder and decoder is nn.Layer, not str
|
|
class MLM(nn.Layer):
|
|
def __init__(self,
|
|
odim: int,
|
|
encoder: nn.Layer,
|
|
decoder: Optional[nn.Layer],
|
|
postnet_layers: int=0,
|
|
postnet_chans: int=0,
|
|
postnet_filts: int=0,
|
|
text_masking: bool=False):
|
|
|
|
super().__init__()
|
|
self.odim = odim
|
|
self.encoder = encoder
|
|
self.decoder = decoder
|
|
self.vocab_size = encoder.text_embed[0]._num_embeddings
|
|
|
|
if self.decoder is None or not (hasattr(self.decoder,
|
|
'output_layer') and
|
|
self.decoder.output_layer is not None):
|
|
self.sfc = nn.Linear(self.encoder._output_size, odim)
|
|
else:
|
|
self.sfc = None
|
|
if text_masking:
|
|
self.text_sfc = nn.Linear(
|
|
self.encoder.text_embed[0]._embedding_dim,
|
|
self.vocab_size,
|
|
weight_attr=self.encoder.text_embed[0]._weight_attr)
|
|
else:
|
|
self.text_sfc = None
|
|
|
|
self.postnet = (None if postnet_layers == 0 else Postnet(
|
|
idim=self.encoder._output_size,
|
|
odim=odim,
|
|
n_layers=postnet_layers,
|
|
n_chans=postnet_chans,
|
|
n_filts=postnet_filts,
|
|
use_batch_norm=True,
|
|
dropout_rate=0.5, ))
|
|
|
|
def inference(
|
|
self,
|
|
speech: paddle.Tensor,
|
|
text: paddle.Tensor,
|
|
masked_pos: paddle.Tensor,
|
|
speech_mask: paddle.Tensor,
|
|
text_mask: paddle.Tensor,
|
|
speech_seg_pos: paddle.Tensor,
|
|
text_seg_pos: paddle.Tensor,
|
|
span_bdy: List[int],
|
|
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
|
|
'''
|
|
Args:
|
|
speech (paddle.Tensor): input speech (1, Tmax, D).
|
|
text (paddle.Tensor): input text (1, Tmax2).
|
|
masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
|
|
speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
|
|
text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
|
|
speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
|
|
text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
|
|
span_bdy (List[int]): masked mel boundary of input speech (2,)
|
|
use_teacher_forcing (bool): whether to use teacher forcing
|
|
Returns:
|
|
List[Tensor]:
|
|
eg:
|
|
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
|
|
'''
|
|
|
|
z_cache = None
|
|
if use_teacher_forcing:
|
|
before_outs, zs, *_ = self.forward(
|
|
speech=speech,
|
|
text=text,
|
|
masked_pos=masked_pos,
|
|
speech_mask=speech_mask,
|
|
text_mask=text_mask,
|
|
speech_seg_pos=speech_seg_pos,
|
|
text_seg_pos=text_seg_pos)
|
|
if zs is None:
|
|
zs = before_outs
|
|
|
|
speech = speech.squeeze(0)
|
|
outs = [speech[:span_bdy[0]]]
|
|
outs += [zs[0][span_bdy[0]:span_bdy[1]]]
|
|
outs += [speech[span_bdy[1]:]]
|
|
return outs
|
|
return None
|
|
|
|
|
|
class MLMEncAsDecoder(MLM):
|
|
def forward(self,
|
|
speech: paddle.Tensor,
|
|
text: paddle.Tensor,
|
|
masked_pos: paddle.Tensor,
|
|
speech_mask: paddle.Tensor,
|
|
text_mask: paddle.Tensor,
|
|
speech_seg_pos: paddle.Tensor,
|
|
text_seg_pos: paddle.Tensor):
|
|
# feats: (Batch, Length, Dim)
|
|
# -> encoder_out: (Batch, Length2, Dim2)
|
|
encoder_out, h_masks = self.encoder(
|
|
speech=speech,
|
|
text=text,
|
|
masked_pos=masked_pos,
|
|
speech_mask=speech_mask,
|
|
text_mask=text_mask,
|
|
speech_seg_pos=speech_seg_pos,
|
|
text_seg_pos=text_seg_pos)
|
|
if self.decoder is not None:
|
|
zs, _ = self.decoder(encoder_out, h_masks)
|
|
else:
|
|
zs = encoder_out
|
|
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
|
|
if self.sfc is not None:
|
|
before_outs = paddle.reshape(
|
|
self.sfc(speech_hidden_states),
|
|
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
|
|
else:
|
|
before_outs = speech_hidden_states
|
|
if self.postnet is not None:
|
|
after_outs = before_outs + paddle.transpose(
|
|
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
|
|
[0, 2, 1])
|
|
else:
|
|
after_outs = None
|
|
return before_outs, after_outs, None
|
|
|
|
|
|
class MLMDualMaksing(MLM):
|
|
def forward(self,
|
|
speech: paddle.Tensor,
|
|
text: paddle.Tensor,
|
|
masked_pos: paddle.Tensor,
|
|
speech_mask: paddle.Tensor,
|
|
text_mask: paddle.Tensor,
|
|
speech_seg_pos: paddle.Tensor,
|
|
text_seg_pos: paddle.Tensor):
|
|
# feats: (Batch, Length, Dim)
|
|
# -> encoder_out: (Batch, Length2, Dim2)
|
|
encoder_out, h_masks = self.encoder(
|
|
speech=speech,
|
|
text=text,
|
|
masked_pos=masked_pos,
|
|
speech_mask=speech_mask,
|
|
text_mask=text_mask,
|
|
speech_seg_pos=speech_seg_pos,
|
|
text_seg_pos=text_seg_pos)
|
|
if self.decoder is not None:
|
|
zs, _ = self.decoder(encoder_out, h_masks)
|
|
else:
|
|
zs = encoder_out
|
|
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
|
|
if self.text_sfc:
|
|
text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :]
|
|
text_outs = paddle.reshape(
|
|
self.text_sfc(text_hiddent_states),
|
|
(paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
|
|
if self.sfc is not None:
|
|
before_outs = paddle.reshape(
|
|
self.sfc(speech_hidden_states),
|
|
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
|
|
else:
|
|
before_outs = speech_hidden_states
|
|
if self.postnet is not None:
|
|
after_outs = before_outs + paddle.transpose(
|
|
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
|
|
[0, 2, 1])
|
|
else:
|
|
after_outs = None
|
|
return before_outs, after_outs, text_outs
|
|
|
|
|
|
def build_model_from_file(config_file, model_file):
|
|
|
|
state_dict = paddle.load(model_file)
|
|
model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
|
|
else MLMEncAsDecoder
|
|
|
|
# 构建模型
|
|
with open(config_file) as f:
|
|
conf = CfgNode(yaml.safe_load(f))
|
|
model = build_model(conf, model_class)
|
|
model.set_state_dict(state_dict)
|
|
return model, conf
|
|
|
|
|
|
# select encoder and decoder here
|
|
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
|
|
if isinstance(args.token_list, str):
|
|
with open(args.token_list, encoding="utf-8") as f:
|
|
token_list = [line.rstrip() for line in f]
|
|
|
|
# Overwriting token_list to keep it as "portable".
|
|
args.token_list = list(token_list)
|
|
elif isinstance(args.token_list, (tuple, list)):
|
|
token_list = list(args.token_list)
|
|
else:
|
|
raise RuntimeError("token_list must be str or list")
|
|
|
|
vocab_size = len(token_list)
|
|
odim = 80
|
|
|
|
# Encoder
|
|
encoder_class = MLMEncoder
|
|
|
|
if 'text_masking' in args.model_conf.keys() and args.model_conf[
|
|
'text_masking']:
|
|
args.encoder_conf['text_masking'] = True
|
|
else:
|
|
args.encoder_conf['text_masking'] = False
|
|
|
|
encoder = encoder_class(
|
|
args.input_size, vocab_size=vocab_size, **args.encoder_conf)
|
|
|
|
# Decoder
|
|
if args.decoder != 'no_decoder':
|
|
decoder_class = MLMDecoder
|
|
decoder = decoder_class(
|
|
idim=0,
|
|
input_layer=None,
|
|
**args.decoder_conf, )
|
|
else:
|
|
decoder = None
|
|
|
|
# Build model
|
|
model = model_class(
|
|
odim=odim,
|
|
encoder=encoder,
|
|
decoder=decoder,
|
|
**args.model_conf, )
|
|
|
|
# Initialize
|
|
if args.init is not None:
|
|
initialize(model, args.init)
|
|
|
|
return model
|