Add the feature: calculating the perplexity of TransformerLM

pull/952/head
huangyuxin 3 years ago
parent fc8a7a152e
commit d64f6e9ea5

@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import configargparse


def get_parser():
    """Get default arguments."""
    parser = configargparse.ArgumentParser(
        description="The parser for calculating the perplexity of the transformer language model",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument(
        "--rnnlm", type=str, default=None, help="LM model file to read")
    parser.add_argument(
        "--rnnlm-conf",
        type=str,
        default=None,
        help="LM model config file to read")
    parser.add_argument(
        "--vocab_path",
        type=str,
        default=None,
        help="The vocab file path for token2id")
    parser.add_argument(
        "--bpeprefix",
        type=str,
        default=None,
        help="The prefix of the BPE model to load")
    parser.add_argument(
        "--text_path",
        type=str,
        default=None,
        help="The path of the text file to evaluate")
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of GPUs to use; 0 means run on CPU")
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision", )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="The output directory for storing per-utterance PPL")
    return parser


def main(args):
    parser = get_parser()
    args = parser.parse_args(args)
    # Import lazily so paddle is not loaded before arguments are parsed.
    from deepspeech.exps.lm.transformer.lm_cacu_perplexity import run_get_perplexity
    run_get_perplexity(args)


if __name__ == "__main__":
    main(sys.argv[1:])
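The entry point can also be driven programmatically. A hypothetical call, with paths assembled from the run script at the end of this commit (a sketch, not part of the diff):

main([
    "--rnnlm", "exp/lm/transformer/transformerLM.pdparams",
    "--rnnlm-conf", "conf/lm/transformer.yaml",
    "--vocab_path", "data/lang_char/train_960_unigram5000_units.txt",
    "--bpeprefix", "data/lang_char/train_960_unigram5000",
    "--text_path", "data/test_clean/text",
    "--output_dir", "exp/lm/transformer/perplexity",
    "--ngpu", "0",
])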

@@ -0,0 +1,132 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Calculate the PPL of an LM model.
import os

import numpy as np
import paddle
from paddle.io import DataLoader
from yacs.config import CfgNode

from deepspeech.io.collator import TextCollatorSpm
from deepspeech.io.dataset import TextDataset
from deepspeech.models.lm_interface import dynamic_import_lm
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


def get_config(config_path):
    confs = CfgNode(new_allowed=True)
    confs.merge_from_file(config_path)
    return confs


def load_trained_lm(args):
    lm_config = get_config(args.rnnlm_conf)
    lm_model_module = lm_config.model_module
    lm_class = dynamic_import_lm(lm_model_module)
    lm = lm_class(**lm_config.model)
    model_dict = paddle.load(args.rnnlm)
    lm.set_state_dict(model_dict)
    return lm, lm_config


def write_dict_into_file(ppl_dict, name):
    with open(name, "w") as f:
        for key in ppl_dict:
            f.write(key + " " + ppl_dict[key] + "\n")


def cacu_perplexity(
        lm_model,
        lm_config,
        args,
        log_base=None, ):
    unit_type = lm_config.data.unit_type
    batch_size = lm_config.decoding.batch_size
    num_workers = lm_config.decoding.num_workers
    text_file_path = args.text_path

    total_nll = 0.0
    total_ntokens = 0
    ppl_dict = {}
    len_dict = {}

    text_dataset = TextDataset.from_file(text_file_path)
    collate_fn_text = TextCollatorSpm(
        unit_type=unit_type,
        vocab_filepath=args.vocab_path,
        spm_model_prefix=args.bpeprefix)
    text_loader = DataLoader(
        text_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn_text,
        num_workers=num_workers)

    logger.info("start calculating PPL......")
    for keys, ys_input_pad, ys_output_pad, y_lens in text_loader():
        ys_input_pad = paddle.to_tensor(ys_input_pad)
        ys_output_pad = paddle.to_tensor(ys_output_pad)
        _, unused_logp, unused_count, nll, nll_count = lm_model.forward(
            ys_input_pad, ys_output_pad)
        nll = nll.numpy()
        nll_count = nll_count.numpy()
        for key, _nll, ntoken in zip(keys, nll, nll_count):
            if log_base is None:
                utt_ppl = np.exp(_nll / ntoken)
            else:
                utt_ppl = log_base**(_nll / ntoken / np.log(log_base))
            # Record the PPL of each utterance for debugging or analysis.
            ppl_dict[key] = str(utt_ppl)
            len_dict[key] = str(ntoken)
        total_nll += nll.sum()
        total_ntokens += nll_count.sum()
        logger.info("Current total nll: " + str(total_nll))
        logger.info("Current total tokens: " + str(total_ntokens))
    write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL"))
    write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN"))
    if log_base is None:
        ppl = np.exp(total_nll / total_ntokens)
        log_base = np.e
    else:
        ppl = log_base**(total_nll / total_ntokens / np.log(log_base))
    return ppl, log_base


def run_get_perplexity(args):
    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "gpu:0"
    else:
        device = "cpu"
    paddle.set_device(device)
    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")
    lm_model, lm_config = load_trained_lm(args)
    lm_model.to(device=device, dtype=dtype)
    lm_model.eval()
    PPL, log_base = cacu_perplexity(lm_model, lm_config, args, None)
    logger.info("Final PPL: " + str(PPL))
    logger.info("The log base is: " + "%.2f" % log_base)

@@ -53,7 +53,7 @@ class TextFeaturizer():
        self.maskctc = maskctc
        if vocab_filepath:
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
                vocab_filepath, maskctc)
            self.vocab_size = len(self.vocab_list)
@@ -227,4 +227,4 @@ class TextFeaturizer():
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id

@@ -19,6 +19,7 @@ from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
@@ -33,7 +34,7 @@ logger = Log(__name__).getlog()
def _tokenids(text, keep_transcription_text):
    # for training, text is token ids
    tokens = text  # token ids
    if keep_transcription_text:
@@ -45,6 +46,43 @@ def _tokenids(text, keep_transcription_text):
    return tokens


class TextCollatorSpm():
    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
        assert vocab_filepath is not None
        self.text_featurizer = TextFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix)
        self.eos_id = self.text_featurizer.eos_id
        self.blank_id = self.text_featurizer.blank_id

    def __call__(self, batch):
        """Collate a batch of "key transcript" lines.

        Returns:
            keys (List[str]), ys_input_pad (np.ndarray [B, T]),
            ys_output_pad (np.ndarray [B, T]), y_lens (np.ndarray [B])
        """
        keys = []
        texts_input = []
        texts_output = []
        text_lens = []
        for item in batch:
            key = item.split(" ")[0].strip()
            text = " ".join(item.split(" ")[1:])
            keys.append(key)
            token_ids = self.text_featurizer.featurize(text)
            # <eos> starts the input sequence and ends the target sequence.
            texts_input.append(
                np.array([self.eos_id] + token_ids).astype(np.int64))
            texts_output.append(
                np.array(token_ids + [self.eos_id]).astype(np.int64))
            text_lens.append(len(token_ids) + 1)
        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
        y_lens = np.array(text_lens).astype(np.int64)
        return keys, ys_input_pad, ys_output_pad, y_lens


class SpeechCollatorBase():
    def __init__(
            self,

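TextCollatorSpm relies on a pad_list helper that collator.py presumably imports elsewhere (not visible in this hunk). A minimal NumPy sketch of the behavior assumed above, under a hypothetical name:

import numpy as np

def pad_list_sketch(xs, pad_value):
    # Right-pad a list of 1-D arrays into a single [B, T] batch filled with pad_value.
    max_len = max(len(x) for x in xs)
    out = np.full((len(xs), max_len), pad_value, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        out[i, :len(x)] = x
    return out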
@@ -24,6 +24,25 @@ __all__ = ["ManifestDataset", "TransformDataset"]
logger = Log(__name__).getlog()


class TextDataset(Dataset):
    @classmethod
    def from_file(cls, file_path):
        dataset = cls(file_path)
        return dataset

    def __init__(self, file_path):
        self._manifest = []
        with open(file_path) as f:
            for line in f:
                self._manifest.append(line.strip())

    def __len__(self):
        return len(self._manifest)

    def __getitem__(self, idx):
        return self._manifest[idx]


class ManifestDataset(Dataset):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:

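TextDataset reads one utterance per line in "key transcript" form. A quick usage sketch (the path is taken from the run script below and may differ per setup):

dataset = TextDataset.from_file("data/test_clean/text")
print(len(dataset))  # number of utterances
print(dataset[0])    # raw line: "<utt-key> <transcript>"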
@@ -111,6 +111,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
        """
        batch_size = x.size(0)
        xm = x != 0
        xlen = xm.sum(axis=1)
        if self.embed_drop is not None:
@@ -121,11 +122,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
        y = self.decoder(h)
        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        mask = xm.to(loss.dtype)
        logp = loss * mask.view(-1)
        nll = logp.view(batch_size, -1).sum(-1)
        nll_count = mask.sum(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count
        return logp / count, logp, count, nll, nll_count

    # beam search API (see ScorerInterface)
    def score(self, y: paddle.Tensor, state: Any,

@@ -1,4 +1,8 @@
model_module: transformer
data:
  unit_type: spm
model:
  n_vocab: 5002
  pos_enc: null
@@ -11,3 +15,7 @@ model:
  emb_dropout_rate: 0.0
  att_dropout_rate: 0.0
  tie_weights: False
decoding:
  batch_size: 30
  num_workers: 2
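This is the config that load_trained_lm() reads through yacs. A small sketch of the lookup, assuming the file lives at conf/lm/transformer.yaml as in the run script below:

from yacs.config import CfgNode

confs = CfgNode(new_allowed=True)
confs.merge_from_file("conf/lm/transformer.yaml")
print(confs.model_module)         # "transformer"
print(confs.decoding.batch_size)  # 30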

@ -0,0 +1,53 @@
#!/bin/bash
set -e
stage=-1
stop_stage=100
expdir=exp
datadir=data
ngpu=0
# lm params
rnnlm_config_path=conf/lm/transformer.yaml
lmexpdir=exp/lm/transformer
lang_model=transformerLM.pdparams
#data path
test_set=${datadir}/test_clean/text
test_set_lower=${datadir}/test_clean/text_lower
train_set=train_960
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix=${datadir}/lang_char/${train_set}_${bpemode}${nbpe}
bpemodel=${bpeprefix}.model
vocabfile=${bpeprefix}_units.txt
vocabfile_lower=${bpeprefix}_units_lower.txt
output_dir=${expdir}/lm/transformer/perplexity
mkdir -p ${output_dir}
# Transform the data upper case to lower
if [ -f ${vocabfile} ]; then
tr A-Z a-z < ${vocabfile} > ${vocabfile_lower}
fi
if [ -f ${test_set} ]; then
tr A-Z a-z < ${test_set} > ${test_set_lower}
fi
python ${LM_BIN_DIR}/cacu_perplexity.py \
--rnnlm ${lmexpdir}/${lang_model} \
--rnnlm-conf ${rnnlm_config_path} \
--vocab_path ${vocabfile_lower} \
--bpeprefix ${bpeprefix} \
--text_path ${test_set_lower} \
--output_dir ${output_dir} \
--ngpu ${ngpu}

@@ -51,3 +51,7 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
    CUDA_VISIBLE_DEVICES= ./local/cacu_perplexity.sh || exit -1
fi
