format code

pull/1054/head
Hui Zhang 3 years ago
parent d395c2b8e3
commit 39228864bb

@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash ```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
``` ```

@ -129,8 +129,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1) self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -238,8 +238,10 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -132,7 +132,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -222,8 +223,10 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -139,7 +139,8 @@ class U2STTrainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -235,8 +236,10 @@ class U2STTrainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the impulse response augmentation model.""" """Contains the impulse response augmentation model."""
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the noise perturb augmentation model.""" """Contains the noise perturb augmentation model."""
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains feature normalizers.""" """Contains feature normalizers."""
import json import json
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -27,6 +28,7 @@ __all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
# https://github.com/PaddlePaddle/Paddle/pull/31481 # https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object): class CollateFunc(object):
def __init__(self, feature_func): def __init__(self, feature_func):

@ -15,8 +15,8 @@ from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Text from typing import Text
import jsonlines
import jsonlines
import numpy as np import numpy as np
from paddle.io import DataLoader from paddle.io import DataLoader

@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional from typing import Optional
import jsonlines import jsonlines
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode

@ -309,8 +309,10 @@ class Trainer():
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
# after epoch # after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})

@ -20,6 +20,7 @@ import time
import wave import wave
from time import gmtime from time import gmtime
from time import strftime from time import strftime
import jsonlines import jsonlines
__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"] __all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]

@ -252,8 +252,10 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}". self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score)) format(self.epoch, total_loss, F1_score))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save( self.save(
tag=self.epoch, infos={"val_loss": total_loss, tag=self.epoch, infos={"val_loss": total_loss,

@ -19,9 +19,10 @@ import argparse
import functools import functools
import os import os
import tempfile import tempfile
import jsonlines
from collections import Counter from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS from paddlespeech.s2t.frontend.utility import SOS

@ -16,6 +16,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import jsonlines import jsonlines
key_whitelist = set(['feat', 'text', 'syllable', 'phone']) key_whitelist = set(['feat', 'text', 'syllable', 'phone'])

@ -15,9 +15,10 @@
"""format manifest with more metadata.""" """format manifest with more metadata."""
import argparse import argparse
import functools import functools
import jsonlines
import json import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type from paddlespeech.s2t.io.utility import feat_type

@ -16,6 +16,7 @@
import argparse import argparse
import functools import functools
import json import json
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

@ -3,6 +3,7 @@
import argparse import argparse
import functools import functools
from pathlib import Path from pathlib import Path
import jsonlines import jsonlines
from utils.utility import add_arguments from utils.utility import add_arguments

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib import hashlib
import json
import os import os
import sys import sys
import tarfile import tarfile

Loading…
Cancel
Save