fix: move writer into VisualDL in train scripts

pull/935/head
TianYuan 4 years ago
parent 950d17cbcf
commit 88668513b1

@@ -9,34 +9,34 @@ English | [简体中文](README_ch.md)
</p>
<div align="center">
<h3>
<a href="https://github.com/Mingxue-Xu/DeepSpeech#quick-start"> Quick Start </a>
| <a href="https://github.com/Mingxue-Xu/DeepSpeech#tutorials"> Tutorials </a>
| <a href="https://github.com/Mingxue-Xu/DeepSpeech#model-list"> Models List </a>
</div>
------------------------------------------------------------------------------------
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
<!---
why they should use your module,
how they can install it,
how they can use it
-->
**PaddleSpeech** is an open-source toolkit on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for two critical speech tasks - **Automatic Speech Recognition (ASR)** and **Text-To-Speech Synthesis (TTS)** - with modules built on state-of-the-art and influential models.
Via an easy-to-use, efficient, flexible and scalable implementation, our vision is to empower both industrial applications and academic research, covering training, inference & testing modules, and deployment. This toolkit also features:
- **Fast and Light-weight**: we provide a high-speed and ultra-lightweight model that is convenient for industrial deployment.
- **Rule-based Chinese frontend**: our frontend contains Text Normalization (TN) and Grapheme-to-Phoneme (G2P, including Polyphone and Tone Sandhi). Moreover, we use self-defined linguistic rules to adapt to Chinese contexts.
- **Varieties of Functions that Vitalize Research**:
  - *Integration of mainstream models and datasets*: the toolkit implements modules that participate in the whole pipeline of both ASR and TTS, and uses datasets like LibriSpeech, LJSpeech, AIShell, etc. See also [model lists](#models-list) for more details.
  - *Support of ASR streaming and non-streaming data*: this toolkit contains non-streaming/streaming models like [DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf), [Transformer](https://arxiv.org/abs/1706.03762), [Conformer](https://arxiv.org/abs/2005.08100) and [U2](https://arxiv.org/pdf/2012.05481.pdf).
Let's install PaddleSpeech with only a few lines of code!
> Note: The official name is still deepspeech. 2021/10/26
@@ -44,7 +44,7 @@ Let's install PaddleSpeech with only a few lines of code!
# 1. Install essential libraries and paddlepaddle first.
# install prerequisites
sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev libsndfile1
# `pip install paddlepaddle-gpu` instead if you are using GPU.
pip install paddlepaddle
# 2. Then install PaddleSpeech.
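After installation, a quick way to verify the setup (a suggested extra step, not part of this diff) is PaddlePaddle's built-in self-check:

```python
# Optional sanity check; run_check() compiles and runs a tiny program and
# reports whether PaddlePaddle (CPU or GPU) works on this machine.
import paddle

paddle.utils.run_check()
```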
@@ -109,7 +109,7 @@ If you want to try more functions like training and tuning, please see [ASR gett
PaddleSpeech ASR supports a lot of mainstream models, which are summarized as follows. For more information, please refer to [ASR Models](./docs/source/asr/released_model.md).
<!---
The current hyperlinks redirect to [Previous Parakeet](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples).
-->
<table>
@@ -125,7 +125,7 @@ The current hyperlinks redirect to [Previous Parakeet](https://github.com/Paddle
<tr>
<td rowspan="6">Acoustic Model</td>
<td rowspan="4" >Aishell</td>
<td >2 Conv + 5 LSTM layers with only forward direction </td>
<td>
<a href = "https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds_online.5rnn.debug.tar.gz">Ds2 Online Aishell Model</a>
</td>
@@ -199,7 +199,7 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model
<tr>
<td> Text Frontend</td>
<td colspan="2"> &emsp; </td>
<td>
<a href = "https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/text_frontend">chinese-frontend</a>
</td>
</tr>
@@ -292,11 +292,11 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model
</table>
## Tutorials
Normally, [Speech SoTA](https://paperswithcode.com/area/speech) gives you an overview of the hot academic topics in speech. If you want to focus on the two tasks in PaddleSpeech, you will find the following guidelines helpful for grasping the core ideas.
The original ASR module is based on [Baidu's DeepSpeech](https://arxiv.org/abs/1412.5567), which is an independent product named [DeepSpeech](https://deepspeech.readthedocs.io). However, the toolkit aligns almost all the SoTA modules in the pipeline. Specifically, these modules are
* [Data Preparation](docs/source/asr/data_preparation.md)
* [Data Augmentation](docs/source/asr/augmentation.md)
@@ -318,4 +318,3 @@ PaddleSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Acknowledgement
PaddleSpeech depends on many open source repos. See [references](docs/source/asr/reference.md) for more information.

@@ -10,13 +10,13 @@ Example instruction to install paddlepaddle via pip is listed below.
### PaddlePaddle with GPU
```bash
# PaddlePaddle for CUDA10.1
python -m pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
# PaddlePaddle for CUDA10.2
python -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
# PaddlePaddle for CUDA11.0
python -m pip install paddlepaddle-gpu==2.1.2.post110 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
# PaddlePaddle for CUDA11.2
python -m pip install paddlepaddle-gpu==2.1.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
### PaddlePaddle with CPU

@@ -37,4 +37,3 @@
```bash
bash local/export.sh ckpt_path saved_jit_model_path
```

@@ -53,8 +53,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
    peek_example = minibatch[0]
    assert len(peek_example.shape) == 1, "text example is an 1D tensor"
-    lengths = [example.shape[0] for example in
-               minibatch]  # assume (channel, n_samples) or (n_samples, )
+    lengths = [example.shape[0] for example in minibatch
+               ]  # assume (channel, n_samples) or (n_samples, )
    max_len = np.max(lengths)
    batch = []
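For reference, `batch_text_id` pads variable-length 1D text-ID arrays to the batch maximum before stacking. A self-contained sketch of that pattern, with illustrative names rather than the library's exact code:

```python
import numpy as np

def pad_to_max(minibatch, pad_id=0, dtype=np.int64):
    """Pad 1D int arrays to the longest length in the batch, then stack."""
    lengths = [example.shape[0] for example in minibatch]
    max_len = np.max(lengths)
    batch = [
        np.pad(example, (0, max_len - example.shape[0]),
               mode="constant", constant_values=pad_id)
        for example in minibatch
    ]
    return np.array(batch, dtype=dtype)

# pad_to_max([np.array([1, 2, 3]), np.array([4, 5])]) has shape (2, 3)
```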

@@ -25,7 +25,6 @@ from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
-from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn
@@ -160,8 +159,7 @@ def train_sp(args, config):
    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
-        writer = LogWriter(str(output_dir))
-        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    # print(trainer.extensions)
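Consolidated, the same call-site change recurs in every train script touched below; `trainer` and `output_dir` are the locals already defined in each `train_sp`:

```python
# Before this commit, every script built its own writer:
#     writer = LogWriter(str(output_dir))
#     trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
# After it, the VisualDL extension constructs the LogWriter internally:
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
```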

@@ -28,7 +28,6 @@ from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from paddle.optimizer.lr import MultiStepDecay
-from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
@@ -219,8 +218,7 @@ def train_sp(args, config):
    if dist.get_rank() == 0:
        trainer.extend(
            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
-        writer = LogWriter(str(trainer.out))
-        trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
+        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots),
            trigger=(config.save_interval_steps, 'iteration'))

@@ -28,7 +28,6 @@ from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam  # No RAdam
from paddle.optimizer.lr import StepDecay
-from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
@@ -193,8 +192,7 @@ def train_sp(args, config):
    if dist.get_rank() == 0:
        trainer.extend(
            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
-        writer = LogWriter(str(trainer.out))
-        trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
+        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots),
            trigger=(config.save_interval_steps, 'iteration'))

@@ -25,7 +25,6 @@ from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
-from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.am_batch_fn import speedyspeech_batch_fn
@@ -153,8 +152,7 @@ def train_sp(args, config):
    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
-        writer = LogWriter(str(output_dir))
-        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    trainer.run()

@@ -67,16 +67,19 @@ class LJSpeechCollector(object):
        # Sort by text_len in descending order
        texts = [
-            i for i, _ in sorted(
+            i
+            for i, _ in sorted(
                zip(texts, text_lens), key=lambda x: x[1], reverse=True)
        ]
        mels = [
-            i for i, _ in sorted(
+            i
+            for i, _ in sorted(
                zip(mels, text_lens), key=lambda x: x[1], reverse=True)
        ]
        mel_lens = [
-            i for i, _ in sorted(
+            i
+            for i, _ in sorted(
                zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
        ]
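All three comprehensions apply the same ordering, descending by `text_lens`, to parallel lists. A standalone sketch of that idea with made-up data; computing the permutation once is an equivalent alternative to zipping each list separately:

```python
import numpy as np

texts = [np.array([7, 8]), np.array([1, 2, 3]), np.array([9])]
mels = [np.zeros((2, 80)), np.zeros((3, 80)), np.zeros((1, 80))]
text_lens = [len(t) for t in texts]

# Compute the descending-length permutation once, then apply it to every
# parallel list so they all stay aligned.
order = sorted(range(len(texts)), key=lambda i: text_lens[i], reverse=True)
texts = [texts[i] for i in order]
mels = [mels[i] for i in order]
text_lens = [text_lens[i] for i in order]

print(text_lens)  # [3, 2, 1]
```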

@@ -25,7 +25,6 @@ from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
-from visualdl import LogWriter
from yacs.config import CfgNode
from parakeet.datasets.am_batch_fn import transformer_single_spk_batch_fn
@@ -148,8 +147,7 @@ def train_sp(args, config):
    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
-        writer = LogWriter(str(output_dir))
-        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    # print(trainer.extensions)

@@ -1,231 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict

import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from timer import timer

from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class PWGUpdater(StandardUpdater):
    def __init__(self,
                 models: Dict[str, Layer],
                 optimizers: Dict[str, Optimizer],
                 criterions: Dict[str, Layer],
                 schedulers: Dict[str, LRScheduler],
                 dataloader: DataLoader,
                 discriminator_train_start_steps: int,
                 lambda_adv: float,
                 output_dir=None):
        self.models = models
        self.generator: Layer = models['generator']
        self.discriminator: Layer = models['discriminator']

        self.optimizers = optimizers
        self.optimizer_g: Optimizer = optimizers['generator']
        self.optimizer_d: Optimizer = optimizers['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        self.dataloader = dataloader

        self.discriminator_train_start_steps = discriminator_train_start_steps
        self.lambda_adv = lambda_adv
        self.state = UpdaterState(iteration=0, epoch=0)

        self.train_iterator = iter(self.dataloader)

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # parse batch
        wav, mel = batch

        # Generator
        noise = paddle.randn(wav.shape)
        with timer() as t:
            wav_ = self.generator(noise, mel)
            # logging.debug(f"Generator takes {t.elapse}s.")

        # initialize
        gen_loss = 0.0

        ## Multi-resolution stft loss
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(wav_, wav)
            # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
        report("train/spectral_convergence_loss", float(sc_loss))
        report("train/log_stft_magnitude_loss", float(mag_loss))
        losses_dict["spectral_convergence_loss"] = float(sc_loss)
        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
        gen_loss += sc_loss + mag_loss

        ## Adversarial loss
        if self.state.iteration > self.discriminator_train_start_steps:
            with timer() as t:
                p_ = self.discriminator(wav_)
                adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
                # logging.debug(
                #     f"Discriminator and adversarial loss takes {t.elapse}s")
            report("train/adversarial_loss", float(adv_loss))
            losses_dict["adversarial_loss"] = float(adv_loss)
            gen_loss += self.lambda_adv * adv_loss

        report("train/generator_loss", float(gen_loss))
        losses_dict["generator_loss"] = float(gen_loss)

        with timer() as t:
            self.optimizer_g.clear_grad()
            gen_loss.backward()
            # logging.debug(f"Backward takes {t.elapse}s.")

        with timer() as t:
            self.optimizer_g.step()
            self.scheduler_g.step()
            # logging.debug(f"Update takes {t.elapse}s.")

        # Discriminator
        if self.state.iteration > self.discriminator_train_start_steps:
            with paddle.no_grad():
                wav_ = self.generator(noise, mel)
            p = self.discriminator(wav)
            p_ = self.discriminator(wav_.detach())
            real_loss = self.criterion_mse(p, paddle.ones_like(p))
            fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
            dis_loss = real_loss + fake_loss
            report("train/real_loss", float(real_loss))
            report("train/fake_loss", float(fake_loss))
            report("train/discriminator_loss", float(dis_loss))
            losses_dict["real_loss"] = float(real_loss)
            losses_dict["fake_loss"] = float(fake_loss)
            losses_dict["discriminator_loss"] = float(dis_loss)

            self.optimizer_d.clear_grad()
            dis_loss.backward()

            self.optimizer_d.step()
            self.scheduler_d.step()

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())


class PWGEvaluator(StandardEvaluator):
    def __init__(self,
                 models,
                 criterions,
                 dataloader,
                 lambda_adv,
                 output_dir=None):
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.dataloader = dataloader
        self.lambda_adv = lambda_adv

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        # logging.debug("Evaluate: ")
        self.msg = "Evaluate: "
        losses_dict = {}

        wav, mel = batch
        noise = paddle.randn(wav.shape)

        with timer() as t:
            wav_ = self.generator(noise, mel)
            # logging.debug(f"Generator takes {t.elapse}s")

        ## Adversarial loss
        with timer() as t:
            p_ = self.discriminator(wav_)
            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
            # logging.debug(
            #     f"Discriminator and adversarial loss takes {t.elapse}s")
        report("eval/adversarial_loss", float(adv_loss))
        losses_dict["adversarial_loss"] = float(adv_loss)
        gen_loss = self.lambda_adv * adv_loss

        # stft loss
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(wav_, wav)
            # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
        report("eval/spectral_convergence_loss", float(sc_loss))
        report("eval/log_stft_magnitude_loss", float(mag_loss))
        losses_dict["spectral_convergence_loss"] = float(sc_loss)
        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
        gen_loss += sc_loss + mag_loss
        report("eval/generator_loss", float(gen_loss))
        losses_dict["generator_loss"] = float(gen_loss)

        # Discriminator
        p = self.discriminator(wav)
        real_loss = self.criterion_mse(p, paddle.ones_like(p))
        fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
        dis_loss = real_loss + fake_loss
        report("eval/real_loss", float(real_loss))
        report("eval/fake_loss", float(fake_loss))
        report("eval/discriminator_loss", float(dis_loss))
        losses_dict["real_loss"] = float(real_loss)
        losses_dict["fake_loss"] = float(fake_loss)
        losses_dict["discriminator_loss"] = float(dis_loss)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
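For reference, the objective this removed updater implements (and the evaluator mirrors) can be summarized as LSGAN-style losses; this is read off the code above, not stated in the source:

```latex
% x = real waveform, \hat{x} = generated waveform, D = discriminator.
% Generator: multi-resolution STFT terms, plus the adversarial term once
% iteration > discriminator_train_start_steps:
L_G = L_{sc} + L_{mag} + \lambda_{adv}\,\big(D(\hat{x}) - 1\big)^2
% Discriminator: real samples pushed toward 1, generated toward 0:
L_D = \big(D(x) - 1\big)^2 + D(\hat{x})^2
```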

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from visualdl import LogWriter
+
from parakeet.training import extension
from parakeet.training.trainer import Trainer
@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

-    def __init__(self, writer):
-        self.writer = writer
+    def __init__(self, logdir):
+        self.writer = LogWriter(str(logdir))

    def __call__(self, trainer: Trainer):
        for k, v in trainer.observation.items():
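Filled out, the updated extension plausibly looks like the sketch below. The `add_scalar` call and the step counter are assumptions about the elided body, not lines from this diff:

```python
from visualdl import LogWriter

from parakeet.training import extension
from parakeet.training.trainer import Trainer


class VisualDL(extension.Extension):
    """Write scalars in trainer.observation to a VisualDL log directory."""
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, logdir):
        # The extension now owns the writer; callers pass a directory instead
        # of constructing a LogWriter themselves.
        self.writer = LogWriter(str(logdir))

    def __call__(self, trainer: Trainer):
        for k, v in trainer.observation.items():
            # assumed: the global iteration count lives on the updater's state
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
```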
