format code

pull/1054/head
Hui Zhang 3 years ago
parent d395c2b8e3
commit 39228864bb

@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
```

@ -129,8 +129,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1)
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
@ -238,8 +238,10 @@ class U2Trainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()

@ -132,7 +132,8 @@ class U2Trainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
@ -222,9 +223,11 @@ class U2Trainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()

@ -139,7 +139,8 @@ class U2STTrainer(Trainer):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
@ -235,9 +236,11 @@ class U2STTrainer(Trainer):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()

@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the impulse response augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the noise perturb augmentation model."""
import jsonlines
from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

@ -13,6 +13,7 @@
# limitations under the License.
"""Contains feature normalizers."""
import json
import jsonlines
import numpy as np
import paddle
@ -26,7 +27,8 @@ from paddlespeech.s2t.utils.log import Log
__all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog()
# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
def __init__(self, feature_func):
@ -62,7 +64,7 @@ class AudioDataset(Dataset):
with jsonlines.open(manifest_path, 'r') as reader:
manifest = list(reader)
if num_samples == -1:
sampled_manifest = manifest
else:

@ -64,7 +64,7 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
char_list.append(MASKCTC)
return char_list
def read_manifest(
manifest_path,
max_input_len=float('inf'),

@ -15,8 +15,8 @@ from typing import Any
from typing import Dict
from typing import List
from typing import Text
import jsonlines
import jsonlines
import numpy as np
from paddle.io import DataLoader
@ -93,7 +93,7 @@ class BatchDataLoader():
# read json data
with jsonlines.open(json_file, 'r') as reader:
self.data_json = list(reader)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')

@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional
import jsonlines
from paddle.io import Dataset
from yacs.config import CfgNode

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False

@ -309,8 +309,10 @@ class Trainer():
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
# after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss})

@ -20,6 +20,7 @@ import time
import wave
from time import gmtime
from time import strftime
import jsonlines
__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]

@ -252,8 +252,10 @@ class Trainer():
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(
tag=self.epoch, infos={"val_loss": total_loss,

@ -19,9 +19,10 @@ import argparse
import functools
import os
import tempfile
import jsonlines
from collections import Counter
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
@ -63,7 +64,7 @@ def count_manifest(counter, text_feature, manifest_path):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line)
@ -73,7 +74,7 @@ def dump_text_manifest(fileobj, manifest_path, key='text'):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)
for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n")

@ -16,6 +16,7 @@
import argparse
from pathlib import Path
from typing import Union
import jsonlines
key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
@ -34,7 +35,7 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]):
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
first_line = manifest_jsons[0]
file_map = {}

@ -15,9 +15,10 @@
"""format manifest with more metadata."""
import argparse
import functools
import jsonlines
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
@ -73,7 +74,7 @@ def main():
for manifest_path in args.manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)
for line_json in manifest_jsons:
output_json = {
"input": [],

@ -16,6 +16,7 @@
import argparse
import functools
import json
import jsonlines
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

@ -3,6 +3,7 @@
import argparse
import functools
from pathlib import Path
import jsonlines
from utils.utility import add_arguments

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import sys
import tarfile

Loading…
Cancel
Save