|
|
|
@ -12,10 +12,14 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import logging
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Sequence
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import distributed as dist
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
from paddle.nn import Layer
|
|
|
|
|
from paddle.optimizer import Optimizer
|
|
|
|
|
|
|
|
|
|
from paddlespeech.t2s.modules.losses import GuidedMultiHeadAttentionLoss
|
|
|
|
|
from paddlespeech.t2s.modules.losses import Tacotron2Loss as TransformerTTSLoss
|
|
|
|
@ -32,14 +36,14 @@ logger.setLevel(logging.INFO)
|
|
|
|
|
class TransformerTTSUpdater(StandardUpdater):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
model,
|
|
|
|
|
optimizer,
|
|
|
|
|
dataloader,
|
|
|
|
|
model: Layer,
|
|
|
|
|
optimizer: Optimizer,
|
|
|
|
|
dataloader: DataLoader,
|
|
|
|
|
init_state=None,
|
|
|
|
|
use_masking=False,
|
|
|
|
|
use_weighted_masking=False,
|
|
|
|
|
output_dir=None,
|
|
|
|
|
bce_pos_weight=5.0,
|
|
|
|
|
use_masking: bool=False,
|
|
|
|
|
use_weighted_masking: bool=False,
|
|
|
|
|
output_dir: Path=None,
|
|
|
|
|
bce_pos_weight: float=5.0,
|
|
|
|
|
loss_type: str="L1",
|
|
|
|
|
use_guided_attn_loss: bool=True,
|
|
|
|
|
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),
|
|
|
|
@ -185,13 +189,13 @@ class TransformerTTSUpdater(StandardUpdater):
|
|
|
|
|
class TransformerTTSEvaluator(StandardEvaluator):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
model,
|
|
|
|
|
dataloader,
|
|
|
|
|
model: Layer,
|
|
|
|
|
dataloader: DataLoader,
|
|
|
|
|
init_state=None,
|
|
|
|
|
use_masking=False,
|
|
|
|
|
use_weighted_masking=False,
|
|
|
|
|
output_dir=None,
|
|
|
|
|
bce_pos_weight=5.0,
|
|
|
|
|
use_masking: bool=False,
|
|
|
|
|
use_weighted_masking: bool=False,
|
|
|
|
|
output_dir: Path=None,
|
|
|
|
|
bce_pos_weight: float=5.0,
|
|
|
|
|
loss_type: str="L1",
|
|
|
|
|
use_guided_attn_loss: bool=True,
|
|
|
|
|
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),
|
|
|
|
|