[s2t] move s2t data preprocess into paddlespeech.dataset (#3189)
* move s2t data preprocess into paddlespeech.dataset * avg model, compute wer, format rsl into paddlespeech.dataset * fix format rsl * fix avg ckpts
pull/3193/head
parent
8c7859d3bc
commit
df3be4acae
@ -0,0 +1,20 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# s2t utils binaries.
|
||||||
|
from .avg_model import main as avg_ckpts_main
|
||||||
|
from .build_vocab import main as build_vocab_main
|
||||||
|
from .compute_mean_std import main as compute_mean_std_main
|
||||||
|
from .compute_wer import main as compute_wer_main
|
||||||
|
from .format_data import main as format_data_main
|
||||||
|
from .format_rsl import main as format_rsl_main
|
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
    """Parse the command-line options of the checkpoint-averaging tool.

    Returns:
        argparse.Namespace: with attributes ``dst_model``, ``ckpt_dir``,
        ``val_best``, ``num``, ``min_epoch`` and ``max_epoch``.
    """
    cli = argparse.ArgumentParser(description='average model')

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ('--dst_model', dict(required=True, help='averaged model')),
        ('--ckpt_dir', dict(required=True, help='ckpt model dir for average')),
        ('--val_best', dict(action="store_true", help='averaged model')),
        ('--num', dict(default=5, type=int, help='nums for averaged model')),
        ('--min_epoch',
         dict(default=0, type=int, help='min epoch used for averaging model')),
        # 65536 is simply "big enough" to not exclude any epoch by default.
        ('--max_epoch',
         dict(
             default=65536,
             type=int,
             help='max epoch used for averaging model')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def average_checkpoints(dst_model="",
                        ckpt_dir="",
                        val_best=True,
                        num=5,
                        min_epoch=0,
                        max_epoch=65536):
    """Average the parameters of several checkpoints and save the result.

    Every ``*.json`` file matched in ``ckpt_dir`` is expected to carry the
    keys ``epoch`` and ``val_loss``; the parameters of epoch ``E`` are loaded
    from ``{ckpt_dir}/{E}.pdparams``. Besides the averaged parameters, a
    ``<dst_model stem>.avg.json`` file describing the selection is written.

    Args:
        dst_model (str): path the averaged parameters are saved to.
        ckpt_dir (str): directory holding the checkpoints and their json.
        val_best (bool): if True pick the ``num`` checkpoints with the lowest
            ``val_loss``; otherwise pick the ``num`` most recently modified.
        num (int): number of checkpoints to average.
        min_epoch (int): lowest epoch (inclusive) eligible for averaging.
        max_epoch (int): highest epoch (inclusive) eligible for averaging.

    Raises:
        AssertionError: if no eligible checkpoint is found, or fewer than
            ``num`` eligible checkpoints exist.
    """
    paddle.set_device('cpu')

    val_scores = []
    # NOTE: `[!train]` is a glob character class, so this matches json files
    # whose first character is none of 't', 'r', 'a', 'i', 'n'.
    jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
    # Newest first, so that the "latest" mode can simply take a prefix.
    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
    for json_path in jsons:
        with open(json_path, 'r') as f:
            dic_json = json.load(f)
        loss = dic_json['val_loss']
        epoch = dic_json['epoch']
        if min_epoch <= epoch <= max_epoch:
            val_scores.append((epoch, loss))
    assert val_scores, f"Not find any valid checkpoints: {val_scores}"
    val_scores = np.array(val_scores)  # columns: (epoch, val_loss)

    if val_best:
        sort_idx = np.argsort(val_scores[:, 1])
        sorted_val_scores = val_scores[sort_idx]
    else:
        sorted_val_scores = val_scores

    best_val_scores = sorted_val_scores[:num, 1]
    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
    avg_val_score = np.mean(best_val_scores)
    print("selected val scores = " + str(best_val_scores))
    print("selected epochs = " + str(selected_epochs))
    print("averaged val score = " + str(avg_val_score))

    path_list = [
        ckpt_dir + '/{}.pdparams'.format(int(epoch))
        for epoch in selected_epochs
    ]
    print(path_list)

    # The sum below is divided by `num`, so exactly `num` checkpoints must
    # have been selected for the result to be a true average.
    assert num == len(path_list)

    avg = None
    for path in path_list:
        print(f'Processing {path}')
        states = paddle.load(path)
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            avg[k] /= num

    paddle.save(avg, dst_model)
    print(f'Saving to {dst_model}')

    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
    with open(meta_path, 'w') as f:
        data = json.dumps({
            "mode": 'val_best' if val_best else 'latest',
            "avg_ckpt": dst_model,
            "val_loss_mean": avg_val_score,
            "ckpts": path_list,
            "epochs": selected_epochs.tolist(),
            "val_losses": best_val_scores.tolist(),
        })
        f.write(data + "\n")


def main():
    """Command-line entry point: parse the options and average checkpoints."""
    args = define_argparse()
    # The option names match the keyword parameters of average_checkpoints.
    average_checkpoints(**vars(args))


if __name__ == '__main__':
    main()
|
@ -0,0 +1,166 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Build vocabulary from manifest files.
|
||||||
|
Each item in vocabulary file is a character.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||||
|
from paddlespeech.s2t.frontend.utility import BLANK
|
||||||
|
from paddlespeech.s2t.frontend.utility import SOS
|
||||||
|
from paddlespeech.s2t.frontend.utility import SPACE
|
||||||
|
from paddlespeech.s2t.frontend.utility import UNK
|
||||||
|
from paddlespeech.utils.argparse import add_arguments
|
||||||
|
from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def count_manifest(counter, text_feature, manifest_path):
    """Add the token counts of one manifest file to ``counter``.

    Args:
        counter: object with an ``update`` method (a ``Counter`` in this
            file) that accumulates the tokens.
        text_feature: object whose ``tokenize(text, replace_space=False)``
            splits a transcript into tokens.
        manifest_path: path of a jsonlines manifest; the ``text`` field of
            each record is either a string or a list of strings.
    """
    # Read the whole manifest before counting anything.
    with jsonlines.open(manifest_path, 'r') as reader:
        records = list(reader)

    for record in records:
        text = record['text']
        if isinstance(text, str):
            sentences = [text]
        else:
            assert isinstance(text, list)
            sentences = text
        for sentence in sentences:
            counter.update(
                text_feature.tokenize(sentence, replace_space=False))
|
||||||
|
|
||||||
|
|
||||||
|
def dump_text_manifest(fileobj, manifest_path, key='text'):
    """Write the ``key`` field of every manifest record, one text per line.

    Args:
        fileobj: writable text file object receiving the lines.
        manifest_path: path of a jsonlines manifest.
        key: record field to dump; its value is either a string or a list
            of strings.
    """
    # Read the whole manifest before writing anything.
    with jsonlines.open(manifest_path, 'r') as reader:
        records = list(reader)

    for record in records:
        value = record[key]
        if isinstance(value, str):
            sentences = [value]
        else:
            assert isinstance(value, list)
            sentences = value
        for sentence in sentences:
            fileobj.write(sentence + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def build_vocab(manifest_paths="",
                vocab_path="examples/librispeech/data/vocab.txt",
                unit_type="char",
                count_threshold=0,
                text_keys='text',
                spm_mode="unigram",
                spm_vocab_size=0,
                spm_model_prefix="",
                spm_character_coverage=0.9995):
    """Build a vocabulary file from the transcripts of manifest files.

    The file starts with ``BLANK`` and ``UNK``, continues with the sorted
    tokens and ends with ``SOS``. When ``unit_type`` is ``'spm'`` a
    sentencepiece model is trained first on the dumped transcripts.

    Args:
        manifest_paths: iterable of jsonlines manifest paths.
        vocab_path (str): file the vocabulary is written to.
        unit_type (str): token unit handed to ``TextFeaturizer``.
        count_threshold (int): tokens seen fewer times than this are dropped.
        text_keys (str | list[str]): manifest field(s) dumped for spm
            training.
        spm_mode (str): sentencepiece ``model_type``.
        spm_vocab_size (int): sentencepiece ``vocab_size``.
        spm_model_prefix (str): sentencepiece ``model_prefix``.
        spm_character_coverage (float): sentencepiece ``character_coverage``.
    """
    # `with` guarantees the vocab file is closed even if training or
    # counting raises.
    with open(vocab_path, 'w', encoding='utf-8') as fout:
        fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
        fout.write(UNK + '\n')  # <unk> must be 1

        if unit_type == 'spm':
            # tools/spm_train --input=$wave_data/lang_char/input.txt
            # --vocab_size=${nbpe} --model_type=${bpemode}
            # --model_prefix=${bpemodel} --input_sentence_size=100000000
            import sentencepiece as spm

            _text_keys = text_keys if isinstance(text_keys,
                                                 list) else [text_keys]
            # Written as UTF-8 explicitly (like the vocab file) instead of
            # relying on the locale's default encoding.
            fp = tempfile.NamedTemporaryFile(
                mode='w', encoding='utf-8', delete=False)
            try:
                with fp:
                    for manifest_path in manifest_paths:
                        for text_key in _text_keys:
                            dump_text_manifest(
                                fp, manifest_path, key=text_key)
                # train
                spm.SentencePieceTrainer.Train(
                    input=fp.name,
                    vocab_size=spm_vocab_size,
                    model_type=spm_mode,
                    model_prefix=spm_model_prefix,
                    input_sentence_size=100000000,
                    character_coverage=spm_character_coverage)
            finally:
                # Remove the temporary corpus even when training fails.
                os.unlink(fp.name)

        # encode
        text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
        counter = Counter()

        for manifest_path in manifest_paths:
            count_manifest(counter, text_feature, manifest_path)

        # Keep tokens that reach the threshold, replacing a space token by
        # `SPACE`, and write them in sorted order.
        tokens = sorted(SPACE if token == ' ' else token
                        for token, count in counter.items()
                        if count >= count_threshold)
        for token in tokens:
            fout.write(token + '\n')

        fout.write(SOS + "\n")  # <sos/eos>
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
    """Parse the command-line options of the vocabulary builder.

    Options are registered through ``add_arguments`` bound to this parser;
    ``main`` forwards the parsed values to ``build_vocab`` as keyword
    arguments, so the option names mirror its parameters.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add_option = functools.partial(add_arguments, argparser=cli)

    # NOTE(review): the `spm_model_prefix` help contains %-style
    # placeholders; if add_arguments forwards help to argparse unchanged,
    # rendering --help may fail -- confirm.
    # yapf: disable
    add_option('unit_type', str, "char", "Unit type, e.g. char, word, spm")
    add_option('count_threshold', int, 0,
               "Truncation threshold for char/word counts.Default 0, no truncate.")
    add_option('vocab_path', str,
               'examples/librispeech/data/vocab.txt',
               "Filepath to write the vocabulary.")
    add_option('manifest_paths', str,
               None,
               "Filepaths of manifests for building vocabulary. "
               "You can provide multiple manifest files.",
               nargs='+',
               required=True)
    add_option('text_keys', str,
               'text',
               "keys of the text in manifest for building vocabulary. "
               "You can provide multiple k.",
               nargs='+')
    # bpe
    add_option('spm_vocab_size', int, 0, "Vocab size for spm.")
    add_option('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
    add_option('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
    add_option('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
    # yapf: enable

    return cli.parse_args()
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse, echo and apply the options."""
    options = define_argparse()
    print_arguments(options, globals())
    # Option names match the keyword parameters of build_vocab.
    build_vocab(**vars(options))


if __name__ == "__main__":
    main()
|
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Compute mean and std for feature normalizer, and save to file."""
|
||||||
|
import argparse
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
|
||||||
|
from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer
|
||||||
|
from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
|
||||||
|
from paddlespeech.utils.argparse import add_arguments
|
||||||
|
from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def compute_cmvn(manifest_path="data/librispeech/manifest.train",
                 output_path="data/librispeech/mean_std.npz",
                 num_samples=2000,
                 num_workers=0,
                 spectrum_type="linear",
                 feat_dim=13,
                 delta_delta=False,
                 stride_ms=10,
                 window_ms=20,
                 sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
    """Build a ``FeatureNormalizer`` from a manifest and write it to a file.

    Args:
        manifest_path (str): manifest handed to ``FeatureNormalizer``.
        output_path (str): file passed to ``write_to_file``.
        num_samples (int): ``num_samples`` of the normalizer.
        num_workers (int): ``num_workers`` of the normalizer.
        spectrum_type (str): feature type of the ``AudioFeaturizer``.
        feat_dim (int): feature dimension of the ``AudioFeaturizer``.
        delta_delta (bool): ``delta_delta`` flag of the ``AudioFeaturizer``.
        stride_ms (int): stride in ms, converted to float.
        window_ms (int): window in ms, converted to float.
        sample_rate (int): used as ``target_sample_rate``.
        use_dB_normalization (bool): forwarded to the ``AudioFeaturizer``.
        target_dB (int): forwarded to the ``AudioFeaturizer``.
    """
    # '{}' is an empty JSON config -- presumably no augmentation is applied;
    # verify against AugmentationPipeline.
    augmenter = AugmentationPipeline('{}')

    featurizer_options = dict(
        spectrum_type=spectrum_type,
        feat_dim=feat_dim,
        delta_delta=delta_delta,
        stride_ms=float(stride_ms),
        window_ms=float(window_ms),
        n_fft=None,
        max_freq=None,
        target_sample_rate=sample_rate,
        use_dB_normalization=use_dB_normalization,
        target_dB=target_dB,
        dither=0.0)
    featurizer = AudioFeaturizer(**featurizer_options)

    def augment_and_featurize(audio_segment):
        # Run the augmentation pipeline on the segment, then featurize it.
        augmenter.transform_audio(audio_segment)
        return featurizer.featurize(audio_segment)

    stats = FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=manifest_path,
        featurize_func=augment_and_featurize,
        num_samples=num_samples,
        num_workers=num_workers)
    stats.write_to_file(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
    """Parse the command-line options of the mean/std computation tool.

    Options are registered through ``add_arguments`` bound to this parser;
    ``main`` forwards the parsed values to ``compute_cmvn`` as keyword
    arguments, so the option names mirror its parameters.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)

    # yapf: disable
    add_arg('manifest_path', str,
            'data/librispeech/manifest.train',
            "Filepath of manifest to compute normalizer's mean and stddev.")

    add_arg('output_path', str,
            'data/librispeech/mean_std.npz',
            "Filepath of write mean and stddev to (.npz).")
    add_arg('num_samples', int, 2000, "# of samples to for statistics.")
    add_arg('num_workers',
            default=0,
            type=int,
            help='num of subprocess workers for processing')

    add_arg('spectrum_type', str,
            'linear',
            "Audio feature type. Options: linear, mfcc, fbank.",
            choices=['linear', 'mfcc', 'fbank'])
    add_arg('feat_dim', int, 13, "Audio feature dim.")
    add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
    add_arg('stride_ms', int, 10, "stride length in ms.")
    # Help text fixed: it previously repeated the stride_ms description.
    add_arg('window_ms', int, 20, "window length in ms.")
    add_arg('sample_rate', int, 16000, "target sample rate.")
    add_arg('use_dB_normalization', bool, True, "do dB normalization.")
    add_arg('target_dB', int, -20, "target dB.")
    # yapf: enable

    args = parser.parse_args()
    return args
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse, echo and apply the options."""
    options = define_argparse()
    print_arguments(options, globals())
    # Option names match the keyword parameters of compute_cmvn.
    compute_cmvn(**vars(options))


if __name__ == "__main__":
    main()
|
@ -0,0 +1,154 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""format manifest with more metadata."""
|
||||||
|
import argparse
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||||
|
from paddlespeech.s2t.frontend.utility import load_cmvn
|
||||||
|
from paddlespeech.s2t.io.utility import feat_type
|
||||||
|
from paddlespeech.utils.argparse import add_arguments
|
||||||
|
from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
    """Parse the command-line options of the manifest formatter.

    Options are registered through ``add_arguments`` bound to this parser;
    ``main`` forwards the parsed values to ``format_data`` as keyword
    arguments, so the option names mirror its parameters.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add_option = functools.partial(add_arguments, argparser=cli)

    # NOTE(review): the `spm_model_prefix` help contains %-style
    # placeholders; if add_arguments forwards help to argparse unchanged,
    # rendering --help may fail -- confirm.
    # yapf: disable
    add_option('manifest_paths', str,
               None,
               "Filepaths of manifests for building vocabulary. "
               "You can provide multiple manifest files.",
               nargs='+',
               required=True)
    add_option('output_path', str, None, "filepath of formated manifest.", required=True)
    add_option('cmvn_path', str,
               'examples/librispeech/data/mean_std.json',
               "Filepath of cmvn.")
    add_option('unit_type', str, "char", "Unit type, e.g. char, word, spm")
    add_option('vocab_path', str,
               'examples/librispeech/data/vocab.txt',
               "Filepath of the vocabulary.")
    # bpe
    add_option('spm_model_prefix', str, None,
               "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm")
    # yapf: enable

    return cli.parse_args()
|
||||||
|
|
||||||
|
def _format_target(text_feature, vocab_size, name, text):
    """Build one entry of the ``output`` list for a single transcript."""
    tokens = text_feature.tokenize(text)
    tokenids = text_feature.featurize(text)
    return {
        'name': name,
        'shape': (len(tokenids), vocab_size),
        'text': text,
        'token': ' '.join(tokens),
        'tokenid': ' '.join(map(str, tokenids)),
    }


def format_data(
        manifest_paths="",
        output_path="",
        cmvn_path="examples/librispeech/data/mean_std.json",
        unit_type="char",
        vocab_path="examples/librispeech/data/vocab.txt",
        spm_model_prefix=""):
    """Rewrite manifests into the input/output jsonlines training format.

    Each written line looks like::

        {
          "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
          "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
          "utt2spk": "111-2222",
          "utt": "111-2222-333"
        }

    Args:
        manifest_paths: iterable of jsonlines manifest paths; each record
            needs ``utt``, ``text``, ``feat`` and ``feat_shape``.
        output_path (str): file the formatted manifest is written to.
        cmvn_path (str): cmvn file; its extension selects the loader
            ``filetype`` and its mean gives the feature dimension.
        unit_type (str): token unit handed to ``TextFeaturizer``.
        vocab_path (str): vocabulary file handed to ``TextFeaturizer``.
        spm_model_prefix (str): spm model prefix handed to
            ``TextFeaturizer``.

    Raises:
        NotImplementedError: for non-``sound`` features or when ``feat`` is
            not a single string.
    """
    # get feat dim
    cmvn_filetype = cmvn_path.split(".")[-1]
    mean, istd = load_cmvn(cmvn_path, filetype=cmvn_filetype)
    feat_dim = mean.shape[0]  #(D)
    print(f"Feature dim: {feat_dim}")

    text_feature = TextFeaturizer(unit_type, vocab_path, spm_model_prefix)
    vocab_size = text_feature.vocab_size
    print(f"Vocab size: {vocab_size}")

    count = 0
    # Opened only after the cmvn/vocab loaded, and closed by `with` even
    # when a record raises.
    with open(output_path, 'w', encoding='utf-8') as fout:
        for manifest_path in manifest_paths:
            with jsonlines.open(str(manifest_path), 'r') as reader:
                manifest_jsons = list(reader)

            for line_json in manifest_jsons:
                output_json = {
                    "input": [],
                    "output": [],
                    'utt': line_json['utt'],
                    'utt2spk': line_json.get('utt2spk', 'global'),
                }

                # output: a single string is one target, a list is several
                # targets sharing the same vocab; names are target1..N.
                text = line_json['text']
                texts = [text] if isinstance(text, str) else text
                for i, item in enumerate(texts, 1):
                    output_json['output'].append(
                        _format_target(text_feature, vocab_size,
                                       f'target{i}', item))

                # input
                feat = line_json['feat']
                if not isinstance(feat, str):
                    # isinstance(feat, list), multi input
                    raise NotImplementedError("not support multi input now!")

                # only one input
                feat_shape = line_json['feat_shape']
                assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
                feat_filetype = feat_type(feat)
                if feat_filetype != 'sound':  # kaldi
                    raise NotImplementedError('no support kaldi feat now!')
                # Copy instead of append: works for tuples too and leaves
                # the manifest record untouched.
                feat_shape = list(feat_shape) + [feat_dim]

                output_json['input'].append({
                    "name": "input1",
                    "shape": feat_shape,
                    "feat": feat,
                    "filetype": feat_filetype,
                })

                fout.write(json.dumps(output_json) + '\n')
                count += 1

        print(f"{manifest_paths} Examples number: {count}")
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse, echo and apply the options."""
    options = define_argparse()
    print_arguments(options, globals())
    # Option names match the keyword parameters of format_data.
    format_data(**vars(options))


if __name__ == "__main__":
    main()
|
@ -0,0 +1,143 @@
|
|||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
format ref/hyp file for `utt text` format to compute CER/WER/MER.
|
||||||
|
|
||||||
|
norm:
|
||||||
|
BAC009S0764W0196 明确了发展目标和重点任务
|
||||||
|
BAC009S0764W0186 实现我国房地产市场的平稳运行
|
||||||
|
|
||||||
|
|
||||||
|
sclite:
|
||||||
|
加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
|
||||||
|
河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from paddlespeech.utils.argparse import print_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def transform_hyp(origin, trans, trans_sclite):
    """Convert a hypothesis jsonlines file into CER/WER scoring formats.

    Args:
        origin: The input json file which contains the model output; each
            record needs ``utt`` and ``hyps`` (the first hypothesis is used).
        trans: The output file for calculating CER/WER, one ``utt text``
            line per utterance. Skipped when empty.
        trans_sclite: The output file for calculating CER/WER using sclite,
            one ``text(utt.wav)`` line per utterance. Skipped when empty.
    """
    input_dict = {}

    # Read-only access is enough; "r+" would need write permission.
    with open(origin, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            input_dict[item["utt"]] = item["hyps"][0]

    if trans:
        with open(trans, "w", encoding="utf8") as f:
            for utt, text in input_dict.items():
                f.write(utt + " " + text + "\n")
        print(f"transform_hyp output: {trans}")

    if trans_sclite:
        # Explicit utf8, like the other files, so non-ASCII text does not
        # depend on the locale's default encoding.
        with open(trans_sclite, "w", encoding="utf8") as f:
            for utt, text in input_dict.items():
                f.write(text + "(" + utt + ".wav" + ")" + "\n")
        print(f"transform_hyp output: {trans_sclite}")
|
||||||
|
|
||||||
|
|
||||||
|
def transform_ref(origin, trans, trans_sclite):
    """Convert a reference jsonlines file into CER/WER scoring formats.

    Args:
        origin: The input json file which contains the references; each
            record needs ``utt`` and ``text``.
        trans: The output file for calculating CER/WER, one ``utt text``
            line per utterance. Skipped when empty.
        trans_sclite: The output file for calculating CER/WER using sclite,
            one ``text(utt.wav)`` line per utterance. Skipped when empty.
    """
    input_dict = {}

    with open(origin, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            input_dict[item["utt"]] = item["text"]

    if trans:
        with open(trans, "w", encoding="utf8") as f:
            for utt, text in input_dict.items():
                f.write(utt + " " + text + "\n")
        # Message fixed: it previously reported "transform_hyp".
        print(f"transform_ref output: {trans}")

    if trans_sclite:
        # Explicit utf8, like the other files, so non-ASCII text does not
        # depend on the locale's default encoding.
        with open(trans_sclite, "w", encoding="utf8") as f:
            for utt, text in input_dict.items():
                f.write(text + "(" + utt + ".wav" + ")" + "\n")
        print(f"transform_ref output: {trans_sclite}")
|
||||||
|
|
||||||
|
|
||||||
|
def define_argparse():
    """Parse the command-line options of the ref/hyp formatting tool.

    All six options are strings defaulting to ``""``.

    Returns:
        argparse.Namespace: with attributes ``origin_hyp``, ``trans_hyp``,
        ``trans_hyp_sclite``, ``origin_ref``, ``trans_ref`` and
        ``trans_ref_sclite``.
    """
    cli = argparse.ArgumentParser(
        prog='format ref/hyp file for compute CER/WER', add_help=True)

    # (flag, help text), registered in this order.
    path_options = (
        ('--origin_hyp', 'origin hyp file'),
        ('--trans_hyp', 'hyp file for caculating CER/WER'),
        ('--trans_hyp_sclite', 'hyp file for caculating CER/WER by sclite'),
        ('--origin_ref', 'origin ref file'),
        ('--trans_ref', 'ref file for caculating CER/WER'),
        ('--trans_ref_sclite', 'ref file for caculating CER/WER by sclite'),
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, type=str, default="", help=help_text)

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def format_result(origin_hyp="",
                  trans_hyp="",
                  trans_hyp_sclite="",
                  origin_ref="",
                  trans_ref="",
                  trans_ref_sclite=""):
    """Format the hypothesis and/or reference file for CER/WER scoring.

    Each side is processed only when its ``origin_*`` path is non-empty.

    Args:
        origin_hyp (str): origin hyp file.
        trans_hyp (str): plain output for the hypotheses.
        trans_hyp_sclite (str): sclite output for the hypotheses.
        origin_ref (str): origin ref file.
        trans_ref (str): plain output for the references.
        trans_ref_sclite (str): sclite output for the references.
    """
    hyp_job = dict(
        origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
    if hyp_job["origin"]:
        transform_hyp(**hyp_job)

    ref_job = dict(
        origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
    if ref_job["origin"]:
        transform_ref(**ref_job)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse, echo and apply the options."""
    options = define_argparse()
    print_arguments(options, globals())

    # Option names match the keyword parameters of format_result.
    format_result(**vars(options))


if __name__ == '__main__':
    main()
|
Loading…
Reference in new issue