using visualdl; fix read_manifest

pull/1054/head
Hui Zhang 3 years ago
parent 05a6f7767b
commit 7554b6107a

@ -128,8 +128,9 @@ class U2Trainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -237,9 +238,8 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'epoch', {'cv_loss': cv_loss, self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -131,8 +131,8 @@ class U2Trainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -222,9 +222,9 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'epoch', {'cv_loss': cv_loss, self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -138,8 +138,8 @@ class U2STTrainer(Trainer):
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
self.visualizer.add_scalars("step", losses_np_v, for key, val in losses_np_v.items():
self.iteration - 1) self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -235,9 +235,9 @@ class U2STTrainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'epoch', {'cv_loss': cv_loss, self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -16,18 +16,35 @@ import json
import numpy as np import numpy as np
import paddle import paddle
import jsonlines
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import Dataset from paddle.io import Dataset
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def read_manifest(manifest_path):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest.append(json_data)
return manifest
# https://github.com/PaddlePaddle/Paddle/pull/31481 # https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object): class CollateFunc(object):
@ -61,7 +78,11 @@ class CollateFunc(object):
class AudioDataset(Dataset): class AudioDataset(Dataset):
def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0): def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
self._rng = rng if rng else np.random.RandomState(random_seed) self._rng = rng if rng else np.random.RandomState(random_seed)
manifest = read_manifest(manifest_path) manifest = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest.append(json_data)
if num_samples == -1: if num_samples == -1:
sampled_manifest = manifest sampled_manifest = manifest
else: else:

@ -65,7 +65,26 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
return char_list return char_list
def read_manifest( def read_manifest(manifest_path,):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest.append(json_data)
return manifest
def read_manifest_filter(
manifest_path, manifest_path,
max_input_len=float('inf'), max_input_len=float('inf'),
min_input_len=0.0, min_input_len=0.0,
@ -98,7 +117,6 @@ def read_manifest(
Returns: Returns:
List[dict]: Manifest parsing results. List[dict]: Manifest parsing results.
""" """
manifest = [] manifest = []
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:

@ -95,7 +95,7 @@ class ManifestDataset(Dataset):
super().__init__() super().__init__()
# read manifest # read manifest
self._manifest = read_manifest( self._manifest = read_manifest_filter(
manifest_path=manifest_path, manifest_path=manifest_path,
max_input_len=max_input_len, max_input_len=max_input_len,
min_input_len=min_input_len, min_input_len=min_input_len,

@ -19,7 +19,7 @@ from pathlib import Path
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from visualdl import LogWriter
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report from paddlespeech.s2t.training.reporter import report
@ -309,9 +309,8 @@ class Trainer():
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars( self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
'epoch', {'cv_loss': cv_loss, self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
'lr': self.lr_scheduler()}, self.epoch)
# after epoch # after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
@ -427,7 +426,7 @@ class Trainer():
unexpected behaviors. unexpected behaviors.
""" """
# visualizer # visualizer
visualizer = SummaryWriter(logdir=str(self.visual_dir)) visualizer = LogWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer self.visualizer = visualizer
@mp_tools.rank_zero_only @mp_tools.rank_zero_only

@ -34,7 +34,7 @@ from speechtask.punctuation_restoration.model.lstm import RnnLm
from speechtask.punctuation_restoration.utils import layer_tools from speechtask.punctuation_restoration.utils import layer_tools
from speechtask.punctuation_restoration.utils import mp_tools from speechtask.punctuation_restoration.utils import mp_tools
from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint
from tensorboardX import SummaryWriter from visualdl import LogWriter
__all__ = ["Trainer", "Tester"] __all__ = ["Trainer", "Tester"]
@ -252,10 +252,8 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}". self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score)) format(self.epoch, total_loss, F1_score))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalars("epoch", { self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
"total_loss": total_loss, self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
"lr": self.lr_scheduler()
}, self.epoch)
self.save( self.save(
tag=self.epoch, infos={"val_loss": total_loss, tag=self.epoch, infos={"val_loss": total_loss,
@ -341,7 +339,7 @@ class Trainer():
unexpected behaviors. unexpected behaviors.
""" """
# visualizer # visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir)) visualizer = LogWriter(logdir=str(self.output_dir))
self.visualizer = visualizer self.visualizer = visualizer
@mp_tools.rank_zero_only @mp_tools.rank_zero_only

@ -40,7 +40,6 @@ snakeviz
soundfile~=0.10 soundfile~=0.10
sox sox
soxbindings soxbindings
tensorboardX
textgrid textgrid
timer timer
tqdm tqdm

@ -19,11 +19,11 @@ import argparse
import functools import functools
import os import os
import tempfile import tempfile
import jsonlines
from collections import Counter from collections import Counter
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.frontend.utility import SOS from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK from paddlespeech.s2t.frontend.utility import UNK
@ -59,13 +59,21 @@ args = parser.parse_args()
def count_manifest(counter, text_feature, manifest_path): def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = read_manifest(manifest_path) manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False) line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line) counter.update(line)
def dump_text_manifest(fileobj, manifest_path, key='text'): def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = read_manifest(manifest_path) manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n") fileobj.write(line_json[key] + "\n")

@ -42,6 +42,7 @@ def read_manifest(manifest_path):
for json_line in open(manifest_path, 'r'): for json_line in open(manifest_path, 'r'):
try: try:
json_data = json.loads(json_line) json_data = json.loads(json_line)
manifest.append(json_data)
except Exception as e: except Exception as e:
raise IOError("Error reading manifest: %s" % str(e)) raise IOError("Error reading manifest: %s" % str(e))
return manifest return manifest

Loading…
Cancel
Save