From d647cde8702f6ba225ccef8feb708e73d89ec73c Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Thu, 4 Nov 2021 09:50:33 +0000
Subject: [PATCH] change the lm dataset dir

---
 examples/librispeech/s2/path.sh                    |  6 +-
 .../exps/lm/transformer/lm_cacu_perplexity.py      |  4 +-
 paddlespeech/s2t/io/collator.py                    | 38 ----------
 paddlespeech/s2t/io/dataset.py                     | 19 -----
 paddlespeech/s2t/models/lm/dataset.py              | 74 +++++++++++++++++++
 5 files changed, 81 insertions(+), 60 deletions(-)
 create mode 100644 paddlespeech/s2t/models/lm/dataset.py

diff --git a/examples/librispeech/s2/path.sh b/examples/librispeech/s2/path.sh
index ad6b69139..840835c2f 100644
--- a/examples/librispeech/s2/path.sh
+++ b/examples/librispeech/s2/path.sh
@@ -14,6 +14,10 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
 MODEL=u2_kaldi
 export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
 
+LM_MODEL=transformer
+export LM_BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/lm/${LM_MODEL}/bin
+
+
 # srilm
 export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
@@ -25,4 +29,4 @@ export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
 [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
 export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
 [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!"
-[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh
\ No newline at end of file
+[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh
diff --git a/paddlespeech/s2t/exps/lm/transformer/lm_cacu_perplexity.py b/paddlespeech/s2t/exps/lm/transformer/lm_cacu_perplexity.py
index ab0ec8f0e..e628f3234 100644
--- a/paddlespeech/s2t/exps/lm/transformer/lm_cacu_perplexity.py
+++ b/paddlespeech/s2t/exps/lm/transformer/lm_cacu_perplexity.py
@@ -19,8 +19,8 @@ import paddle
 from paddle.io import DataLoader
 from yacs.config import CfgNode
 
-from paddlespeech.s2t.io.collator import TextCollatorSpm
-from paddlespeech.s2t.io.dataset import TextDataset
+from paddlespeech.s2t.models.lm.dataset import TextCollatorSpm
+from paddlespeech.s2t.models.lm.dataset import TextDataset
 from paddlespeech.s2t.models.lm_interface import dynamic_import_lm
 from paddlespeech.s2t.utils.log import Log
 
diff --git a/paddlespeech/s2t/io/collator.py b/paddlespeech/s2t/io/collator.py
index a500f10c9..cb7349d00 100644
--- a/paddlespeech/s2t/io/collator.py
+++ b/paddlespeech/s2t/io/collator.py
@@ -19,7 +19,6 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
 from paddlespeech.s2t.frontend.featurizer.speech_featurizer import SpeechFeaturizer
-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
@@ -46,43 +45,6 @@ def _tokenids(text, keep_transcription_text):
     return tokens
 
 
-class TextCollatorSpm():
-    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
-        assert (vocab_filepath is not None)
-        self.text_featurizer = TextFeaturizer(
-            unit_type=unit_type,
-            vocab_filepath=vocab_filepath,
-            spm_model_prefix=spm_model_prefix)
-        self.eos_id = self.text_featurizer.eos_id
-        self.blank_id = self.text_featurizer.blank_id
-
-    def __call__(self, batch):
-        """
-        return type [List, np.array [B, T], np.array [B, T], np.array[B]]
-        """
-        keys = []
-        texts = []
-        texts_input = []
-        texts_output = []
-        text_lens = []
-
-        for idx, item in enumerate(batch):
-            key = item.split(" ")[0].strip()
-            text = " ".join(item.split(" ")[1:])
-            keys.append(key)
-            token_ids = self.text_featurizer.featurize(text)
-            texts_input.append(
-                np.array([self.eos_id] + token_ids).astype(np.int64))
-            texts_output.append(
-                np.array(token_ids + [self.eos_id]).astype(np.int64))
-            text_lens.append(len(token_ids) + 1)
-
-        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
-        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
-        y_lens = np.array(text_lens).astype(np.int64)
-        return keys, ys_input_pad, ys_output_pad, y_lens
-
-
 class SpeechCollatorBase():
     def __init__(
             self,
diff --git a/paddlespeech/s2t/io/dataset.py b/paddlespeech/s2t/io/dataset.py
index 121410c8b..8690879c5 100644
--- a/paddlespeech/s2t/io/dataset.py
+++ b/paddlespeech/s2t/io/dataset.py
@@ -24,25 +24,6 @@ __all__ = ["ManifestDataset", "TransformDataset"]
 logger = Log(__name__).getlog()
 
 
-class TextDataset(Dataset):
-    @classmethod
-    def from_file(cls, file_path):
-        dataset = cls(file_path)
-        return dataset
-
-    def __init__(self, file_path):
-        self._manifest = []
-        with open(file_path) as f:
-            for line in f:
-                self._manifest.append(line.strip())
-
-    def __len__(self):
-        return len(self._manifest)
-
-    def __getitem__(self, idx):
-        return self._manifest[idx]
-
-
 class ManifestDataset(Dataset):
     @classmethod
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
diff --git a/paddlespeech/s2t/models/lm/dataset.py b/paddlespeech/s2t/models/lm/dataset.py
new file mode 100644
index 000000000..4059dfe2c
--- /dev/null
+++ b/paddlespeech/s2t/models/lm/dataset.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from paddle.io import Dataset
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.io.utility import pad_list
+
+
+class TextDataset(Dataset):
+    @classmethod
+    def from_file(cls, file_path):
+        dataset = cls(file_path)
+        return dataset
+
+    def __init__(self, file_path):
+        self._manifest = []
+        with open(file_path) as f:
+            for line in f:
+                self._manifest.append(line.strip())
+
+    def __len__(self):
+        return len(self._manifest)
+
+    def __getitem__(self, idx):
+        return self._manifest[idx]
+
+
+class TextCollatorSpm():
+    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
+        assert (vocab_filepath is not None)
+        self.text_featurizer = TextFeaturizer(
+            unit_type=unit_type,
+            vocab_filepath=vocab_filepath,
+            spm_model_prefix=spm_model_prefix)
+        self.eos_id = self.text_featurizer.eos_id
+        self.blank_id = self.text_featurizer.blank_id
+
+    def __call__(self, batch):
+        """
+        return type [List, np.array [B, T], np.array [B, T], np.array[B]]
+        """
+        keys = []
+        texts = []
+        texts_input = []
+        texts_output = []
+        text_lens = []
+
+        for idx, item in enumerate(batch):
+            key = item.split(" ")[0].strip()
+            text = " ".join(item.split(" ")[1:])
+            keys.append(key)
+            token_ids = self.text_featurizer.featurize(text)
+            texts_input.append(
+                np.array([self.eos_id] + token_ids).astype(np.int64))
+            texts_output.append(
+                np.array(token_ids + [self.eos_id]).astype(np.int64))
+            text_lens.append(len(token_ids) + 1)
+
+        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
+        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
+        y_lens = np.array(text_lens).astype(np.int64)
+        return keys, ys_input_pad, ys_output_pad, y_lens
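
Below is a minimal sketch of how the relocated TextDataset and TextCollatorSpm
wire into a paddle DataLoader after this move. The data paths, unit_type, and
spm_model_prefix values are hypothetical placeholders, not files shipped with
this patch:

from paddle.io import DataLoader

from paddlespeech.s2t.models.lm.dataset import TextCollatorSpm
from paddlespeech.s2t.models.lm.dataset import TextDataset

# Each line of the text file is expected to be "utt-key token token ...".
dataset = TextDataset.from_file("data/test_clean/text")  # hypothetical path

collate_fn = TextCollatorSpm(
    unit_type="spm",                                      # hypothetical values
    vocab_filepath="data/lang_char/vocab.txt",
    spm_model_prefix="data/lang_char/bpe_unigram_5000")

loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

for keys, ys_input_pad, ys_output_pad, y_lens in loader:
    # ys_input_pad is [B, T] with eos_id prepended (eos doubles as the
    # start-of-sentence symbol); ys_output_pad is [B, T] with eos_id
    # appended; both are padded to the longest utterance with blank_id.
    pass

Prepending and appending eos_id gives the shifted input/target pair a language
model needs, and y_lens records the unpadded length of each utterance so the
padding positions can be masked out of the loss.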