format code

pull/1054/head
Hui Zhang 3 years ago
parent d395c2b8e3
commit 39228864bb

@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash ```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
``` ```

@ -129,8 +129,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1) self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -238,8 +238,10 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -132,7 +132,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -222,9 +223,11 @@ class U2Trainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -139,7 +139,8 @@ class U2STTrainer(Trainer):
losses_np_v = losses_np.copy() losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()}) losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items(): for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1) self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -235,9 +236,11 @@ class U2STTrainer(Trainer):
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch() self.new_epoch()

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the impulse response augmentation model.""" """Contains the impulse response augmentation model."""
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the noise perturb augmentation model.""" """Contains the noise perturb augmentation model."""
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains feature normalizers.""" """Contains feature normalizers."""
import json import json
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
@ -26,7 +27,8 @@ from paddlespeech.s2t.utils.log import Log
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
# https://github.com/PaddlePaddle/Paddle/pull/31481 # https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object): class CollateFunc(object):
def __init__(self, feature_func): def __init__(self, feature_func):
@ -62,7 +64,7 @@ class AudioDataset(Dataset):
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
manifest = list(reader) manifest = list(reader)
if num_samples == -1: if num_samples == -1:
sampled_manifest = manifest sampled_manifest = manifest
else: else:

@ -64,7 +64,7 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
char_list.append(MASKCTC) char_list.append(MASKCTC)
return char_list return char_list
def read_manifest( def read_manifest(
manifest_path, manifest_path,
max_input_len=float('inf'), max_input_len=float('inf'),

@ -15,8 +15,8 @@ from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Text from typing import Text
import jsonlines
import jsonlines
import numpy as np import numpy as np
from paddle.io import DataLoader from paddle.io import DataLoader
@ -93,7 +93,7 @@ class BatchDataLoader():
# read json data # read json data
with jsonlines.open(json_file, 'r') as reader: with jsonlines.open(json_file, 'r') as reader:
self.data_json = list(reader) self.data_json = list(reader)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size( self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr') self.data_json, mode='asr')

@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional from typing import Optional
import jsonlines import jsonlines
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -309,8 +309,10 @@ class Trainer():
logger.info( logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
# after epoch # after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})

@ -20,6 +20,7 @@ import time
import wave import wave
from time import gmtime from time import gmtime
from time import strftime from time import strftime
import jsonlines import jsonlines
__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"] __all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]

@ -252,8 +252,10 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}". self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score)) format(self.epoch, total_loss, F1_score))
if self.visualizer: if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar(
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save( self.save(
tag=self.epoch, infos={"val_loss": total_loss, tag=self.epoch, infos={"val_loss": total_loss,

@ -19,9 +19,10 @@ import argparse
import functools import functools
import os import os
import tempfile import tempfile
import jsonlines
from collections import Counter from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS from paddlespeech.s2t.frontend.utility import SOS
@ -63,7 +64,7 @@ def count_manifest(counter, text_feature, manifest_path):
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:
manifest_jsons.append(json_data) manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False) line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line) counter.update(line)
@ -73,7 +74,7 @@ def dump_text_manifest(fileobj, manifest_path, key='text'):
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:
manifest_jsons.append(json_data) manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n") fileobj.write(line_json[key] + "\n")

@ -16,6 +16,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import jsonlines import jsonlines
key_whitelist = set(['feat', 'text', 'syllable', 'phone']) key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
@ -34,7 +35,7 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]):
with jsonlines.open(str(manifest_path), 'r') as reader: with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader) manifest_jsons = list(reader)
first_line = manifest_jsons[0] first_line = manifest_jsons[0]
file_map = {} file_map = {}

@ -15,9 +15,10 @@
"""format manifest with more metadata.""" """format manifest with more metadata."""
import argparse import argparse
import functools import functools
import jsonlines
import json import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type from paddlespeech.s2t.io.utility import feat_type
@ -73,7 +74,7 @@ def main():
for manifest_path in args.manifest_paths: for manifest_path in args.manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader: with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader) manifest_jsons = list(reader)
for line_json in manifest_jsons: for line_json in manifest_jsons:
output_json = { output_json = {
"input": [], "input": [],

@ -16,6 +16,7 @@
import argparse import argparse
import functools import functools
import json import json
import jsonlines import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

@ -3,6 +3,7 @@
import argparse import argparse
import functools import functools
from pathlib import Path from pathlib import Path
import jsonlines import jsonlines
from utils.utility import add_arguments from utils.utility import add_arguments

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib import hashlib
import json
import os import os
import sys import sys
import tarfile import tarfile

Loading…
Cancel
Save