[CLI][Demo][Text]Refactor punctuation_restoration. (#1013)
* Refactor punctuation_restoration. * Add text cli and punc demo.pull/1126/head
parent
f6ca14c5fa
commit
074559fe90
@ -0,0 +1,73 @@
|
||||
# Punctuation Restoration
|
||||
|
||||
## Introduction
|
||||
Punctuation restoration is a common post-processing problem for Automatic Speech Recognition (ASR) systems. It is important to improve the readability of the transcribed text for the human reader and facilitate NLP tasks.
|
||||
|
||||
This demo is an implementation to restore punctuation from raw text. It can be done with a single command or a few lines of Python using `PaddleSpeech`.
|
||||
|
||||
## Usage
|
||||
### 1. Installation
|
||||
```bash
|
||||
pip install paddlespeech
|
||||
```
|
||||
|
||||
### 2. Prepare Input
|
||||
The input of this demo is raw (unpunctuated) text in the language selected by `lang`, passed via the `--input` argument.
|
||||
|
||||
|
||||
### 3. Usage
|
||||
- Command Line(Recommended)
|
||||
```bash
|
||||
paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
|
||||
```
|
||||
Usage:
|
||||
```bash
|
||||
paddlespeech text --help
|
||||
```
|
||||
Arguments:
|
||||
- `input`(required): Input raw text.
|
||||
- `task`: Choose subtask. Default: `punc`.
|
||||
- `model`: Model type of text task. Default: `ernie_linear_wudao`.
|
||||
- `lang`: Choose model language. Default: `zh`.
|
||||
- `config`: Config of text task. Use pretrained model when it is None. Default: `None`.
|
||||
- `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
|
||||
- `punc_vocab`: Vocabulary file of punctuation restoration task. Default: `None`.
|
||||
- `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment.
|
||||
|
||||
Output:
|
||||
```bash
|
||||
[2021-12-14 19:50:22,200] [ INFO] [log.py] [L57] - Text Result:
|
||||
今天的天气真不错啊!你下午有空吗?我想约你一起去吃饭。
|
||||
```
|
||||
|
||||
- Python API
|
||||
```python
|
||||
import paddle
|
||||
from paddlespeech.cli import TextExecutor
|
||||
|
||||
text_executor = TextExecutor()
|
||||
result = text_executor(
|
||||
text='今天的天气真不错啊你下午有空吗我想约你一起去吃饭',
|
||||
task='punc',
|
||||
model='ernie_linear_wudao',
|
||||
lang='zh',
|
||||
config=None,
|
||||
ckpt_path=None,
|
||||
punc_vocab=None,
|
||||
device=paddle.get_device())
|
||||
print('Text Result: \n{}'.format(result))
|
||||
```
|
||||
Output:
|
||||
```bash
|
||||
Text Result:
|
||||
今天的天气真不错啊!你下午有空吗?我想约你一起去吃饭。
|
||||
```
|
||||
|
||||
|
||||
### 4. Pretrained Models
|
||||
|
||||
Here is a list of pretrained models released by PaddleSpeech that can be used via the command line and the Python API:
|
||||
|
||||
| Model | Task | Language
|
||||
| :--- | :---: | :---:
|
||||
| ernie_linear_wudao| punc(Punctuation Restoration) | zh
|
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
paddlespeech text --input 今天的天气真好啊你下午有空吗我想约你一起去吃饭
|
@ -0,0 +1,36 @@
|
||||
data:
|
||||
dataset_type: Ernie
|
||||
train_path: data/iwslt2012_zh/train.txt
|
||||
dev_path: data/iwslt2012_zh/dev.txt
|
||||
test_path: data/iwslt2012_zh/test.txt
|
||||
data_params:
|
||||
pretrained_token: ernie-1.0
|
||||
punc_path: data/iwslt2012_zh/punc_vocab
|
||||
seq_len: 100
|
||||
batch_size: 64
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
checkpoint:
|
||||
kbest_n: 5
|
||||
latest_n: 10
|
||||
metric_type: F1
|
||||
|
||||
model_type: ErnieLinear
|
||||
|
||||
model_params:
|
||||
pretrained_token: ernie-1.0
|
||||
num_classes: 4
|
||||
|
||||
training:
|
||||
n_epoch: 100
|
||||
lr: !!float 1e-5
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 5.0
|
||||
log_interval: 10
|
||||
log_path: log/train_ernie_linear.log
|
||||
|
||||
testing:
|
||||
log_path: log/test_ernie_linear.log
|
@ -0,0 +1,9 @@
|
||||
#!/bin/bash

# Fetch and unpack the IWSLT2012 archive, unless a `data` directory is
# already present in the current working directory.
if [ ! -d data ]; then
    # `-c` lets wget resume a partially downloaded archive.
    # Stop on failure instead of reporting a successful preparation.
    wget -c https://paddlespeech.bj.bcebos.com/datasets/iwslt2012.tar.gz || {
        echo "Failed to download iwslt2012.tar.gz!"
        exit 1
    }
    tar -xzf iwslt2012.tar.gz || {
        echo "Failed to extract iwslt2012.tar.gz!"
        exit 1
    }
fi

echo "Finish data preparation."
exit 0
|
@ -0,0 +1,281 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
|
||||
from ...s2t.utils.dynamic_import import dynamic_import
|
||||
from ..executor import BaseExecutor
|
||||
from ..log import logger
|
||||
from ..utils import cli_register
|
||||
from ..utils import download_and_decompress
|
||||
from ..utils import MODEL_HOME
|
||||
|
||||
__all__ = ['TextExecutor']
|
||||
|
||||
pretrained_models = {
|
||||
# The tags for pretrained_models should be "{model_name}_{dataset}-{task}-{lang}",
# e.g. "ernie_linear_wudao-punc-zh".
# Command line and python api use "{model_name}_{dataset}" as --model, usage:
# "paddlespeech text --task punc --model ernie_linear_wudao --lang zh --input <text>"
|
||||
"ernie_linear_wudao-punc-zh": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_wudao-punc-zh.tar.gz',
|
||||
'md5':
|
||||
'12283e2ddde1797c5d1e57036b512746',
|
||||
'cfg_path':
|
||||
'ckpt/model_config.json',
|
||||
'ckpt_path':
|
||||
'ckpt/model_state.pdparams',
|
||||
'vocab_file':
|
||||
'punc_vocab.txt',
|
||||
},
|
||||
}
|
||||
|
||||
model_alias = {
|
||||
"ernie_linear": "paddlespeech.text.models:ErnieLinear",
|
||||
}
|
||||
|
||||
tokenizer_alias = {
|
||||
"ernie_linear": "paddlenlp.transformers:ErnieTokenizer",
|
||||
}
|
||||
|
||||
|
||||
@cli_register(name='paddlespeech.text', description='Text infer command.')
|
||||
class TextExecutor(BaseExecutor):
|
||||
def __init__(self):
    """Create the executor and build the ``paddlespeech.text`` argument parser.

    The accepted ``--model`` values are derived from the keys of
    ``pretrained_models`` by keeping everything before the first ``'-'``.
    The ``--device`` default is whatever ``paddle.get_device()`` reports at
    construction time.
    """
    super(TextExecutor, self).__init__()

    self.parser = argparse.ArgumentParser(
        prog='paddlespeech.text', add_help=True)
    self.parser.add_argument(
        '--input', type=str, required=True, help='Input text.')
    self.parser.add_argument(
        '--task',
        type=str,
        default='punc',
        choices=['punc'],
        help='Choose text task.')
    self.parser.add_argument(
        '--model',
        type=str,
        default='ernie_linear_wudao',
        choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
        help='Choose model type of text task.')
    self.parser.add_argument(
        '--lang',
        type=str,
        default='zh',
        choices=['zh', 'en'],
        help='Choose model language.')
    # Help text previously read "Config of cls task. Use deault config ...",
    # which named the wrong task and contained a typo.
    self.parser.add_argument(
        '--config',
        type=str,
        default=None,
        help='Config of text task. Use default config when it is None.')
    self.parser.add_argument(
        '--ckpt_path',
        type=str,
        default=None,
        help='Checkpoint file of model.')
    self.parser.add_argument(
        '--punc_vocab',
        type=str,
        default=None,
        help='Vocabulary file of punctuation restoration task.')
    self.parser.add_argument(
        '--device',
        type=str,
        default=paddle.get_device(),
        help='Choose device to execute model inference.')
|
||||
|
||||
def _get_pretrained_path(self, tag: str) -> os.PathLike:
    """Resolve the pretrained resources registered under ``tag``.

    The entry ``pretrained_models[tag]`` is handed to
    ``download_and_decompress`` together with the target directory
    ``MODEL_HOME/<tag>``; the absolute form of the path it returns is
    logged and returned.
    """
    assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
        tag)

    target_dir = os.path.join(MODEL_HOME, tag)
    resource_dir = os.path.abspath(
        download_and_decompress(pretrained_models[tag], target_dir))
    logger.info('Use pretrained model stored in: {}'.format(resource_dir))

    return resource_dir
|
||||
|
||||
def _init_from_path(self,
                    task: str='punc',
                    model_type: str='ernie_linear_wudao',
                    lang: str='zh',
                    cfg_path: Optional[os.PathLike]=None,
                    ckpt_path: Optional[os.PathLike]=None,
                    vocab_file: Optional[os.PathLike]=None):
    """Load the model, tokenizer and punctuation vocabulary (once).

    If any of ``cfg_path``, ``ckpt_path`` or ``vocab_file`` is None, all
    three are taken from the pretrained resources registered under the tag
    ``"{model_type}-{task}-{lang}"``; otherwise the three given paths are
    used. When ``self.model`` already exists the call returns immediately,
    so later calls with different arguments have no effect.

    Raises:
        NotImplementedError: if ``task`` is not ``'punc'``.
    """
    if hasattr(self, 'model'):
        logger.info('Model had been initialized.')
        return

    self.task = task

    if cfg_path is None or ckpt_path is None or vocab_file is None:
        tag = '-'.join([model_type, task, lang])
        self.res_path = self._get_pretrained_path(tag)
        self.cfg_path = os.path.join(self.res_path,
                                     pretrained_models[tag]['cfg_path'])
        self.ckpt_path = os.path.join(self.res_path,
                                      pretrained_models[tag]['ckpt_path'])
        self.vocab_file = os.path.join(self.res_path,
                                       pretrained_models[tag]['vocab_file'])
    else:
        self.cfg_path = os.path.abspath(cfg_path)
        self.ckpt_path = os.path.abspath(ckpt_path)
        self.vocab_file = os.path.abspath(vocab_file)

    # Drop the trailing "_{dataset}" part: "ernie_linear_wudao" -> "ernie_linear".
    model_name = model_type[:model_type.rindex('_')]
    if self.task == 'punc':
        # punc list
        self._punc_list = []
        # Decode explicitly as UTF-8 (the encoding the dataset loaders in
        # this package use for vocab files) instead of relying on the
        # platform's locale default.
        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                self._punc_list.append(line.strip())

        # model
        model_class = dynamic_import(model_name, model_alias)
        tokenizer_class = dynamic_import(model_name, tokenizer_alias)
        self.model = model_class(
            cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
        self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
    else:
        raise NotImplementedError

    self.model.eval()
|
||||
|
||||
def _clean_text(self, text):
|
||||
text = text.lower()
|
||||
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
|
||||
text = re.sub(f'[{"".join([p for p in self._punc_list][1:])}]', '',
|
||||
text)
|
||||
return text
|
||||
|
||||
def preprocess(self, text: Union[str, os.PathLike]):
    """Clean and tokenize ``text`` and store the result in ``self._inputs``.

    For the ``'punc'`` task the cleaned text is split into single
    characters and passed to ``self.tokenizer``; its ``input_ids``,
    ``token_type_ids`` and ``seq_len`` outputs are stored under the keys
    ``'input_ids'``, ``'seg_ids'`` and ``'seq_len'``.

    Raises:
        AssertionError: if nothing is left of ``text`` after cleaning.
        NotImplementedError: if ``self.task`` is not ``'punc'``.
    """
    logger.info("Preprocessing input text: " + text)
    if self.task != 'punc':
        raise NotImplementedError

    clean_text = self._clean_text(text)
    assert len(clean_text) > 0, f'Invalid input string: {text}'

    encoded = self.tokenizer(
        list(clean_text), return_length=True, is_split_into_words=True)

    # (key in self._inputs, key in the tokenizer output)
    key_map = (
        ('input_ids', 'input_ids'),
        ('seg_ids', 'token_type_ids'),
        ('seq_len', 'seq_len'), )
    for dst_key, src_key in key_map:
        self._inputs[dst_key] = encoded[src_key]
|
||||
|
||||
@paddle.no_grad()
def infer(self):
    """Run the model on ``self._inputs`` and store predictions.

    ``input_ids`` and ``seg_ids`` are converted to tensors and given a
    leading axis via ``unsqueeze(0)``; the arg-max of the model logits over
    the last axis (with that leading axis squeezed away again) is stored in
    ``self._outputs['preds']``.

    Raises:
        NotImplementedError: if ``self.task`` is not ``'punc'``.
    """
    if self.task != 'punc':
        raise NotImplementedError

    batch_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
    batch_segs = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
    logits, _ = self.model(batch_ids, batch_segs)

    self._outputs['preds'] = paddle.argmax(logits, axis=-1).squeeze(0)
|
||||
|
||||
def postprocess(self) -> Union[str, os.PathLike]:
    """Rebuild the punctuated text from the tokens and predicted labels.

    The first and last positions of the sequence are dropped (slice
    ``[1:seq_len - 1]``) from both the input ids and the predictions. Each
    remaining token is emitted, followed by ``self._punc_list[label]``
    whenever its label is non-zero.

    Returns:
        The text with punctuation inserted.

    Raises:
        NotImplementedError: if ``self.task`` is not ``'punc'``.
    """
    if self.task != 'punc':
        raise NotImplementedError

    ids = self._inputs['input_ids']
    length = self._inputs['seq_len']
    preds = self._outputs['preds']

    tokens = self.tokenizer.convert_ids_to_tokens(ids[1:length - 1])
    labels = preds[1:length - 1].tolist()
    assert len(tokens) == len(labels)

    pieces = []
    for token, label in zip(tokens, labels):
        pieces.append(token)
        if label != 0:  # label 0 means "no punctuation after this token"
            pieces.append(self._punc_list[label])

    return ''.join(pieces)
|
||||
|
||||
def execute(self, argv: List[str]) -> bool:
    """Command line entry point.

    Parses ``argv`` with ``self.parser``, calls the executor with the
    parsed values and logs the result.

    Returns:
        True when the call succeeded, False when it raised an exception
        (the exception is logged, not propagated).
    """
    args = self.parser.parse_args(argv)

    # Same positional order as the parameters of ``__call__``.
    call_args = (
        args.input,
        args.task,
        args.model,
        args.lang,
        args.config,
        args.ckpt_path,
        args.punc_vocab,
        args.device, )

    try:
        res = self(*call_args)
        logger.info('Text Result:\n{}'.format(res))
        return True
    except Exception as e:
        logger.exception(e)
        return False
|
||||
|
||||
def __call__(
        self,
        text: str,
        task: str='punc',
        model: str='ernie_linear_wudao',
        lang: str='zh',
        config: os.PathLike=None,
        ckpt_path: os.PathLike=None,
        punc_vocab: os.PathLike=None,
        device: str=paddle.get_device(), ):
    """Python API: run the whole text pipeline on ``text``.

    Selects ``device`` via ``paddle.set_device``, initializes the model
    resources, then runs ``preprocess``, ``infer`` and ``postprocess`` in
    that order.

    Returns:
        Whatever ``postprocess`` returns for the current task.
    """
    paddle.set_device(device)
    self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)

    self.preprocess(text)
    self.infer()

    # Result of the text task.
    return self.postprocess()
|
@ -1,7 +0,0 @@
|
||||
data
|
||||
glove
|
||||
.pyc
|
||||
checkpoints
|
||||
epoch
|
||||
__pycache__
|
||||
glove.840B.300d.zip
|
@ -1,34 +0,0 @@
|
||||
data:
|
||||
language: chinese
|
||||
raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/PFDSJ #path to raw dataset
|
||||
raw_train_file: train
|
||||
raw_dev_file: dev
|
||||
raw_test_file: test
|
||||
vocab_file: vocab
|
||||
punc_file: punc_vocab
|
||||
save_path: data/PFDSJ #path to save dataset
|
||||
seq_len: 100
|
||||
batch_size: 10
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
model_type: blstm
|
||||
model_params:
|
||||
vocab_size: 3751
|
||||
embedding_size: 200
|
||||
hidden_size: 100
|
||||
num_layers: 3
|
||||
num_class: 5
|
||||
init_scale: 0.1
|
||||
|
||||
training:
|
||||
n_epoch: 32
|
||||
lr: !!float 1e-4
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 5.0
|
||||
log_interval: 10
|
||||
|
||||
|
||||
|
@ -1,7 +0,0 @@
|
||||
type: chinese
|
||||
raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/iwslt2012_zh #path to raw dataset
|
||||
raw_train_file: iwslt2012_train_zh
|
||||
raw_dev_file: iwslt2010_dev_zh
|
||||
raw_test_file: biaobei_asr
|
||||
punc_file: punc_vocab
|
||||
save_path: data/iwslt2012_zh #path to save dataset
|
@ -1,49 +0,0 @@
|
||||
data:
|
||||
dataset_type: Bert
|
||||
train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train
|
||||
dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev
|
||||
test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012_revise
|
||||
data_params:
|
||||
pretrained_token: bert-base-chinese
|
||||
punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab
|
||||
seq_len: 100
|
||||
batch_size: 64
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
checkpoint:
|
||||
kbest_n: 5
|
||||
latest_n: 10
|
||||
metric_type: F1
|
||||
|
||||
|
||||
model_type: BertBLSTM
|
||||
model_params:
|
||||
pretrained_token: bert-base-chinese
|
||||
output_size: 4
|
||||
dropout: 0.0
|
||||
bert_size: 768
|
||||
blstm_size: 128
|
||||
num_blstm_layers: 2
|
||||
init_scale: 0.1
|
||||
|
||||
# model_type: BertChLinear
|
||||
# model_params: bert-base-chinese
|
||||
# pretrained_token:
|
||||
# output_size: 4
|
||||
# dropout: 0.0
|
||||
# bert_size: 768
|
||||
|
||||
training:
|
||||
n_epoch: 100
|
||||
lr: !!float 1e-5
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 5.0
|
||||
log_interval: 10
|
||||
log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/bertBLSTM_zh0812.log
|
||||
|
||||
testing:
|
||||
log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_bertBLSTM_zh0812.log
|
||||
|
@ -1,42 +0,0 @@
|
||||
data:
|
||||
dataset_type: Bert
|
||||
train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train
|
||||
dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev
|
||||
test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012
|
||||
data_params:
|
||||
pretrained_token: bert-base-chinese
|
||||
punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab
|
||||
seq_len: 100
|
||||
batch_size: 32
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
checkpoint:
|
||||
kbest_n: 10
|
||||
latest_n: 10
|
||||
metric_type: F1
|
||||
|
||||
|
||||
model_type: BertLinear
|
||||
model_params:
|
||||
pretrained_token: bert-base-uncased
|
||||
output_size: 4
|
||||
dropout: 0.2
|
||||
bert_size: 768
|
||||
hiddensize: 1568
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 50
|
||||
lr: !!float 1e-5
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 5.0
|
||||
log_interval: 10
|
||||
log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/train_linear0812.log
|
||||
|
||||
testing:
|
||||
log_interval: 10
|
||||
log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_linear0812.log
|
||||
|
@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 1 ];then
|
||||
echo "usage: ${0} data_pre_conf"
|
||||
echo $1
|
||||
exit -1
|
||||
fi
|
||||
|
||||
data_pre_conf=$1
|
||||
|
||||
python3 -u ${BIN_DIR}/pre_data.py \
|
||||
--config ${data_pre_conf}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,26 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 2 ];then
|
||||
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_name=$2
|
||||
|
||||
mkdir -p exp
|
||||
|
||||
python3 -u ${BIN_DIR}/train.py \
|
||||
--ngpu ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--output exp/${ckpt_name}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,7 +0,0 @@
|
||||
type: english
|
||||
raw_path: /data4/mahaoxin/PaddleSpeechTask/data/english/iwslt2012_en #path to raw dataset
|
||||
raw_train_file: iwslt2012_train_en
|
||||
raw_dev_file: iwslt2010_dev_en
|
||||
raw_test_file: iwslt2011_test_en
|
||||
punc_file: punc_vocab
|
||||
save_path: data/iwslt2012_en #path to save dataset
|
@ -1,47 +0,0 @@
|
||||
data:
|
||||
dataset_type: Bert
|
||||
train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train
|
||||
dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev
|
||||
test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011
|
||||
data_params:
|
||||
pretrained_token: bert-base-uncased #english
|
||||
punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab
|
||||
seq_len: 50
|
||||
batch_size: 32
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
checkpoint:
|
||||
kbest_n: 10
|
||||
latest_n: 10
|
||||
|
||||
model_type: BertBLSTM
|
||||
model_params:
|
||||
pretrained_token: bert-base-uncased
|
||||
output_size: 4
|
||||
dropout: 0.0
|
||||
bert_size: 768
|
||||
blstm_size: 128
|
||||
num_blstm_layers: 2
|
||||
init_scale: 0.2
|
||||
# model_type: BertChLinear
|
||||
# model_params:
|
||||
# pretrained_token: bert-large-uncased
|
||||
# output_size: 4
|
||||
# dropout: 0.0
|
||||
# bert_size: 768
|
||||
|
||||
training:
|
||||
n_epoch: 100
|
||||
lr: !!float 1e-5
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 5.0
|
||||
log_interval: 10
|
||||
log_path: log/bertBLSTM_base0812.log
|
||||
|
||||
testing:
|
||||
log_path: log/testbertBLSTM_base0812.log
|
||||
|
||||
|
@ -1,39 +0,0 @@
|
||||
data:
|
||||
dataset_type: Bert
|
||||
train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train
|
||||
dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev
|
||||
test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011
|
||||
data_params:
|
||||
pretrained_token: bert-base-uncased #english
|
||||
punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab
|
||||
seq_len: 100
|
||||
batch_size: 32
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
checkpoint:
|
||||
kbest_n: 10
|
||||
latest_n: 10
|
||||
|
||||
model_type: BertLinear
|
||||
model_params:
|
||||
pretrained_token: bert-base-uncased
|
||||
output_size: 4
|
||||
dropout: 0.2
|
||||
bert_size: 768
|
||||
hiddensize: 1568
|
||||
|
||||
training:
|
||||
n_epoch: 20
|
||||
lr: !!float 1e-5
|
||||
lr_decay: 1.0
|
||||
weight_decay: !!float 1e-06
|
||||
global_grad_clip: 3.0
|
||||
log_interval: 10
|
||||
log_path: log/train_linear0820.log
|
||||
|
||||
testing:
|
||||
log_path: log/test2011_linear0820.log
|
||||
|
||||
|
@ -1,23 +0,0 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "usage: ${0} ckpt_dir avg_num"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ckpt_dir=${1}
|
||||
average_num=${2}
|
||||
decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
|
||||
|
||||
python3 -u ${BIN_DIR}/avg_model.py \
|
||||
--dst_model ${decode_checkpoint} \
|
||||
--ckpt_dir ${ckpt_dir} \
|
||||
--num ${average_num} \
|
||||
--val_best
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in avg ckpt!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,18 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 1 ];then
|
||||
echo "usage: ${0} config_path"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
config_path=$1
|
||||
|
||||
python3 -u ${BIN_DIR}/pre_data.py \
|
||||
--config ${config_path}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,27 +0,0 @@
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 2 ];then
|
||||
echo "usage: ${0} config_path ckpt_path_prefix"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_prefix=$2
|
||||
|
||||
|
||||
python3 -u ${BIN_DIR}/test.py \
|
||||
--ngpu 1 \
|
||||
--config ${config_path} \
|
||||
--result_file ${ckpt_prefix}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -1,13 +0,0 @@
|
||||
export MAIN_ROOT=${PWD}/../../../
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||
|
||||
|
||||
export BIN_DIR=${MAIN_ROOT}/speechtask/punctuation_restoration/bin
|
@ -1,47 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
source path.sh
|
||||
|
||||
|
||||
## stage, gpu, data_pre_config, train_config, avg_num
|
||||
if [ $# -lt 4 ]; then
|
||||
echo "usage: bash ./run.sh stage gpu train_config avg_num data_config"
|
||||
echo "eg: bash ./run.sh 0 0 train_config 1 data_config "
|
||||
exit -1
|
||||
fi
|
||||
|
||||
stage=$1
|
||||
stop_stage=100
|
||||
gpus=$2
|
||||
conf_path=$3
|
||||
avg_num=$4
|
||||
avg_ckpt=avg_${avg_num}
|
||||
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
|
||||
echo "checkpoint name ${ckpt}"
|
||||
|
||||
if [ $stage -le 0 ]; then
|
||||
if [ $# -eq 5 ]; then
|
||||
data_pre_conf=$5
|
||||
# prepare data
|
||||
bash ./local/data.sh ${data_pre_conf} || exit -1
|
||||
else
|
||||
echo "data_pre_conf is not exist!"
|
||||
exit -1
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
# train model, all `ckpt` under `exp` dir
|
||||
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
# avg n best model
|
||||
bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# test ckpt avg_n
|
||||
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
|
||||
fi
|
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .ernie_crf import ErnieCrf
|
||||
from .ernie_linear import ErnieLinear
|
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddlenlp.layers.crf import LinearChainCrf
|
||||
from paddlenlp.layers.crf import LinearChainCrfLoss
|
||||
from paddlenlp.layers.crf import ViterbiDecoder
|
||||
from paddlenlp.transformers import ErnieForTokenClassification
|
||||
|
||||
|
||||
class ErnieCrf(nn.Layer):
|
||||
def __init__(self,
             num_classes,
             pretrained_token='ernie-1.0',
             crf_lr=100,
             **kwargs):
    """Build an ERNIE token classifier combined with a linear-chain CRF.

    Args:
        num_classes: number of labels; used for both the ERNIE
            classification head and the CRF.
        pretrained_token: name passed to
            ``ErnieForTokenClassification.from_pretrained``.
        crf_lr: forwarded to ``LinearChainCrf`` as ``crf_lr``.
        **kwargs: extra keyword arguments forwarded to ``from_pretrained``.
    """
    super().__init__()
    self.num_classes = num_classes
    self.ernie = ErnieForTokenClassification.from_pretrained(
        pretrained_token, num_classes=num_classes, **kwargs)

    # CRF layer, its loss, and a Viterbi decoder over the CRF transitions.
    crf_layer = LinearChainCrf(
        self.num_classes, crf_lr=crf_lr, with_start_stop_tag=False)
    self.crf = crf_layer
    self.crf_loss = LinearChainCrfLoss(self.crf)
    self.viterbi_decoder = ViterbiDecoder(
        self.crf.transitions, with_start_stop_tag=False)
|
||||
|
||||
def forward(self,
            input_ids,
            token_type_ids=None,
            position_ids=None,
            attention_mask=None,
            lengths=None,
            labels=None):
    """Decode label predictions and, when labels are given, the CRF loss.

    When ``lengths`` is None, every sequence is given the length
    ``input_ids.shape[1]``.

    Returns:
        ``prediction`` (the Viterbi-decoded labels reshaped to ``[-1]``)
        when ``labels`` is None, otherwise the tuple
        ``(avg_loss, prediction)`` where ``avg_loss`` is the mean CRF loss.
    """
    logits = self.ernie(
        input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        position_ids=position_ids)

    if lengths is None:
        full_length = input_ids.shape[1]
        lengths = paddle.ones(
            shape=[input_ids.shape[0]], dtype=paddle.int64) * full_length

    _, prediction = self.viterbi_decoder(logits, lengths)
    prediction = prediction.reshape([-1])

    if labels is None:
        return prediction

    labels = labels.reshape([input_ids.shape[0], -1])
    avg_loss = paddle.mean(self.crf_loss(logits, lengths, labels))
    return avg_loss, prediction
|
@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .model import ErnieLinear
|
@ -0,0 +1,155 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.io import Dataset
|
||||
from paddlenlp.transformers import ErnieTokenizer
|
||||
|
||||
__all__ = ["PuncDataset", "PuncDatasetFromErnieTokenizer"]
|
||||
|
||||
|
||||
class PuncDataset(Dataset):
|
||||
def __init__(self, train_path, vocab_path, punc_path, seq_len=100):
    """Load the vocabularies and the training text, then build the tensors.

    Args:
        train_path: UTF-8 text file; it is split on whitespace into tokens.
        vocab_path: word vocabulary file, one entry per line. Ids 0 and 1
            are reserved for ``'<UNK>'`` and ``'<END>'``.
        punc_path: punctuation vocabulary file, one entry per line. Id 0 is
            reserved for ``" "`` (no punctuation).
        seq_len: number of tokens per sample.
    """
    self.seq_len = seq_len

    self.word2id = self.load_vocab(
        vocab_path, extra_word_list=['<UNK>', '<END>'])
    self.id2word = {v: k for k, v in self.word2id.items()}
    self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
    self.id2punc = {k: v for (v, k) in self.punc2id.items()}

    # Use a context manager so the file handle is closed deterministically
    # (the bare open(...).readlines() left it to the garbage collector).
    with open(train_path, encoding='utf-8') as f:
        self.txt_seqs = [tok for line in f for tok in line.split()]
    self.preprocess(self.txt_seqs)
|
||||
|
||||
def __len__(self):
|
||||
"""return the sentence nums in .txt
|
||||
"""
|
||||
return self.in_len
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.input_data[index], self.label[index]
|
||||
|
||||
def load_vocab(self, vocab_path, extra_word_list=None, encoding='utf-8'):
    """Read a one-entry-per-line vocabulary file into a ``word -> id`` dict.

    Entries of ``extra_word_list`` receive the ids ``0 .. n-1``; the lines
    of the file (whitespace-stripped) receive ``n, n+1, ...`` in file
    order. If an extra word also appears in the file, the extra word's id
    wins.

    Args:
        vocab_path: path of the vocabulary file.
        extra_word_list: words to place in front of the file entries;
            ``None`` means no extra words.
        encoding: text encoding used to read the file.

    Returns:
        dict mapping each word to its integer id.
    """
    extra_words = [] if extra_word_list is None else extra_word_list
    n = len(extra_words)
    # Honour the caller-supplied encoding (it used to be ignored in favour
    # of a hard-coded 'utf-8').
    with open(vocab_path, encoding=encoding) as vf:
        vocab = {word.strip(): i + n for i, word in enumerate(vf)}
    for i, word in enumerate(extra_words):
        vocab[word] = i
    return vocab
|
||||
|
||||
def preprocess(self, txt_seqs: list):
|
||||
input_data = []
|
||||
label = []
|
||||
input_r = []
|
||||
label_r = []
|
||||
|
||||
count = 0
|
||||
length = len(txt_seqs)
|
||||
for token in txt_seqs:
|
||||
count += 1
|
||||
if count == length:
|
||||
break
|
||||
if token in self.punc2id:
|
||||
continue
|
||||
punc = txt_seqs[count]
|
||||
if punc not in self.punc2id:
|
||||
label.append(self.punc2id[" "])
|
||||
input_data.append(
|
||||
self.word2id.get(token, self.word2id["<UNK>"]))
|
||||
input_r.append(token)
|
||||
label_r.append(' ')
|
||||
else:
|
||||
label.append(self.punc2id[punc])
|
||||
input_data.append(
|
||||
self.word2id.get(token, self.word2id["<UNK>"]))
|
||||
input_r.append(token)
|
||||
label_r.append(punc)
|
||||
if len(input_data) != len(label):
|
||||
assert 'error: length input_data != label'
|
||||
|
||||
self.in_len = len(input_data) // self.seq_len
|
||||
len_tmp = self.in_len * self.seq_len
|
||||
input_data = input_data[:len_tmp]
|
||||
label = label[:len_tmp]
|
||||
|
||||
self.input_data = paddle.to_tensor(
|
||||
np.array(input_data, dtype='int64').reshape(-1, self.seq_len))
|
||||
self.label = paddle.to_tensor(
|
||||
np.array(label, dtype='int64').reshape(-1, self.seq_len))
|
||||
|
||||
|
||||
class PuncDatasetFromErnieTokenizer(Dataset):
    """Punctuation-restoration dataset tokenized with an ERNIE tokenizer.

    The training text is read as one stream of whitespace-separated tokens in
    which punctuation marks are tokens of their own. Each non-punctuation
    token is split into sub-token ids by the tokenizer; only the last
    sub-token carries the punctuation label of the token that follows, all
    others get the "no punctuation" label ``" "``. The id stream is cut into
    rows of ``seq_len`` ids; a trailing partial row is dropped.
    """

    def __init__(self,
                 train_path,
                 punc_path,
                 pretrained_token='ernie-1.0',
                 seq_len=100):
        """Load the tokenizer, vocabulary and text, then build samples.

        Args:
            train_path (str): UTF-8 text file of whitespace-separated tokens.
            punc_path (str): Punctuation vocabulary file, one entry per line.
            pretrained_token (str): Name passed to
                ``ErnieTokenizer.from_pretrained``.
            seq_len (int): Number of sub-tokens per sample. Defaults to 100.
        """
        self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
        self.paddingID = self.tokenizer.pad_token_id
        self.seq_len = seq_len

        self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
        self.id2punc = {k: v for (v, k) in self.punc2id.items()}

        # Use a context manager so the file handle is closed deterministically.
        with open(train_path, encoding='utf-8') as train_file:
            self.txt_seqs = [
                token for line in train_file for token in line.split()
            ]
        self.preprocess(self.txt_seqs)

    def __len__(self):
        """Return the number of ``seq_len``-long samples."""
        return self.in_len

    def __getitem__(self, index):
        """Return the ``(input_ids, labels)`` pair stored at ``index``."""
        return self.input_data[index], self.label[index]

    def load_vocab(self, vocab_path, extra_word_list=None, encoding='utf-8'):
        """Read a vocabulary file into a ``{token: id}`` dict.

        Args:
            vocab_path (str): File with one token per line.
            extra_word_list (list, optional): Tokens that receive the ids
                ``0 .. n-1``; tokens read from the file are numbered from
                ``n`` upwards in file order.
            encoding (str): Encoding used to read the file.

        Returns:
            dict: Mapping from token to integer id.
        """
        extra_word_list = list(extra_word_list or [])
        n = len(extra_word_list)
        with open(vocab_path, encoding=encoding) as vf:
            vocab = {word.strip(): i + n for i, word in enumerate(vf)}
        # Extra tokens are written last so they win over identical file lines.
        for i, word in enumerate(extra_word_list):
            vocab[word] = i
        return vocab

    def preprocess(self, txt_seqs: list):
        """Convert the token stream into id arrays.

        Sets ``self.in_len``, ``self.input_data`` and ``self.label`` (both
        int64 numpy arrays of shape ``(in_len, seq_len)``). The last token of
        the stream is never used as an input because it has no following
        token to derive a label from.

        Args:
            txt_seqs (list): Flat list of word and punctuation tokens.
        """
        input_data = []
        label = []
        no_punc_id = self.punc2id[" "]

        # Pair every token with its successor; the successor decides the label.
        for word, next_token in zip(txt_seqs, txt_seqs[1:]):
            if word in self.punc2id:
                # Punctuation marks are labels only, never inputs.
                continue

            # Drop the first and last ids the tokenizer returns
            # (presumably its special start/end markers - confirm).
            ids = self.tokenizer(word)["input_ids"][1:-1]
            if not ids:
                # A word that yields no sub-token must not receive a label,
                # otherwise inputs and labels drift out of alignment.
                continue
            input_data.extend(ids)

            # Only the final sub-token can be followed by punctuation.
            label.extend([no_punc_id] * (len(ids) - 1))
            label.append(self.punc2id.get(next_token, no_punc_id))

        assert len(input_data) == len(label), \
            'error: length input_data != label'

        # Keep only as many ids as fill complete rows of seq_len.
        self.in_len = len(input_data) // self.seq_len
        len_tmp = self.in_len * self.seq_len
        input_data = input_data[:len_tmp]
        label = label[:len_tmp]
        self.input_data = np.array(
            input_data, dtype='int64').reshape(-1, self.seq_len)
        self.label = np.array(label, dtype='int64').reshape(-1, self.seq_len)
|
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddlenlp.transformers import ErnieForTokenClassification
|
||||
|
||||
|
||||
class ErnieLinear(nn.Layer):
    """ERNIE token classifier that returns raw scores and softmax scores."""

    def __init__(self,
                 num_classes=None,
                 pretrained_token='ernie-1.0',
                 cfg_path=None,
                 ckpt_path=None,
                 **kwargs):
        """Build the wrapped ``ErnieForTokenClassification`` model.

        When both ``cfg_path`` and ``ckpt_path`` are given, the model is
        loaded from the directory that contains ``cfg_path``; otherwise it is
        created from ``pretrained_token`` with ``num_classes`` outputs.

        Args:
            num_classes (int, optional): Number of output classes; required
                (positive int) when the file paths are not both given.
            pretrained_token (str): Pretrained model name.
            cfg_path (str, optional): Path to a model config file.
            ckpt_path (str, optional): Path to a checkpoint file. It is only
                checked for existence here.
            **kwargs: Forwarded to ``from_pretrained`` in the
                ``pretrained_token`` case.
        """
        super().__init__()

        load_from_files = cfg_path is not None and ckpt_path is not None
        if not load_from_files:
            valid_classes = isinstance(num_classes, int) and num_classes > 0
            assert valid_classes, 'Argument `num_classes` must be an integer.'
            self.ernie = ErnieForTokenClassification.from_pretrained(
                pretrained_token, num_classes=num_classes, **kwargs)
        else:
            cfg_path = os.path.abspath(os.path.expanduser(cfg_path))
            ckpt_path = os.path.abspath(os.path.expanduser(ckpt_path))

            cfg_exists = os.path.isfile(cfg_path)
            assert cfg_exists, 'Config file is not valid: {}'.format(cfg_path)
            ckpt_exists = os.path.isfile(ckpt_path)
            assert ckpt_exists, 'Checkpoint file is not valid: {}'.format(
                ckpt_path)

            model_dir = os.path.dirname(cfg_path)
            self.ernie = ErnieForTokenClassification.from_pretrained(model_dir)

        self.num_classes = self.ernie.num_classes
        self.softmax = nn.Softmax()

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        """Run the classifier.

        Returns:
            tuple: ``(scores, probs)`` where ``scores`` is the model output
            reshaped to ``[-1, num_classes]`` and ``probs`` is the softmax of
            ``scores``.
        """
        scores = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids)

        # Flatten every leading dimension so each row is one token.
        scores = paddle.reshape(scores, shape=[-1, self.num_classes])
        probs = self.softmax(scores)
        return scores, probs
|
@ -1,6 +0,0 @@
|
||||
numpy
|
||||
pyyaml
|
||||
tensorboardX
|
||||
tqdm
|
||||
ujson
|
||||
yacs
|
@ -1,48 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Data preparation for punctuation_restoration task."""
|
||||
import yaml
|
||||
from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser
|
||||
from speechtask.punctuation_restoration.utils.punct_pre import process_chinese_pure_senetence
|
||||
from speechtask.punctuation_restoration.utils.punct_pre import process_english_pure_senetence
|
||||
from speechtask.punctuation_restoration.utils.utility import print_arguments
|
||||
|
||||
|
||||
# create dataset from raw data files
|
||||
def main(config, args):
    """Prepare the dataset described by ``config`` from raw data files.

    Args:
        config: Mapping whose ``'type'`` entry is ``'chinese'`` or
            ``'english'``; it is handed on to the matching processing routine.
        args: Parsed command-line arguments (not used here).

    Raises:
        ValueError: If ``config['type']`` is neither supported value.
    """
    print("Start preparing data from raw data.")

    data_type = config['type']
    # Reject unknown types first so the happy paths read straight through.
    if data_type != 'chinese' and data_type != 'english':
        print('Error: Type should be chinese or english!!!!')
        raise ValueError('Type should be chinese or english')

    if data_type == 'chinese':
        process_chinese_pure_senetence(config)
    else:
        print('english!!!!')
        process_english_pure_senetence(config)

    print("Finish preparing data.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, echo the arguments, load the
    # YAML config named by --config and run the data preparation.
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # NOTE(review): FullLoader is used here; prefer yaml.safe_load if config
    # files can come from an untrusted source.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # config.freeze()
    print(config)
    main(config, args)
|
@ -1,64 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["TextCollator"]
|
||||
|
||||
|
||||
class TextCollator():
    """Collate (text, punctuation) id sequences into padded int64 arrays."""

    def __init__(self, padding_value):
        """Store the value used to fill positions past a sequence's end.

        Args:
            padding_value: Fill value for padded positions.
        """
        self.padding_value = padding_value

    def __call__(self, batch):
        """Batch examples.

        Args:
            batch (List): Iterable of ``(text, punctuation)`` pairs, each
                element being a sequence of ids.

        Returns:
            tuple(text, punctuation): Two ``np.int64`` arrays of shape
            ``(B, Lmax)``, where ``Lmax`` is the longest sequence of the
            respective kind. An empty batch yields two ``(0, 0)`` arrays.
        """
        texts = []
        punctuations = []
        for text, punctuation in batch:
            texts.append(text)
            punctuations.append(punctuation)

        # Both results are [B, Lmax].
        x_pad = self.pad_sequence(texts).astype(np.int64)
        y_pad = self.pad_sequence(punctuations).astype(np.int64)
        return x_pad, y_pad

    def pad_sequence(self, sequences):
        """Right-pad ``sequences`` with ``padding_value`` to equal length.

        Args:
            sequences (List): Sequences of ids, possibly of unequal length.

        Returns:
            np.ndarray: Array of shape ``(len(sequences), max_len)``.
        """
        # default=0 keeps an empty batch from raising inside max().
        max_len = max((len(s) for s in sequences), default=0)
        out_dims = (len(sequences), max_len)

        out_tensor = np.full(out_dims, self.padding_value)
        for i, tensor in enumerate(sequences):
            length = len(tensor)
            # Slice assignment copies the values into the padded row.
            out_tensor[i, :length] = tensor

        return out_tensor
|
@ -1,55 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import codecs
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import ujson
|
||||
|
||||
# Special token strings used by the text preprocessing helpers in this module.
PAD = "<PAD>"
UNK = "<UNK>"
# Placeholder that word_convert() substitutes for numeric words.
NUM = "<NUM>"
END = "</S>"
SPACE = "_SPACE"
|
||||
|
||||
|
||||
def write_json(filename, dataset):
    """Serialize ``dataset`` as JSON into ``filename`` (UTF-8, overwritten).

    Args:
        filename (str): Destination path.
        dataset: Object handed to ``ujson.dump``.
    """
    with codecs.open(filename, mode="w", encoding="utf-8") as json_file:
        ujson.dump(dataset, json_file)
|
||||
|
||||
|
||||
def word_convert(word, keep_number=True, lowercase=True):
    """Normalize a single word.

    Args:
        word (str): Word to normalize.
        keep_number (bool): When False, a word recognized by ``is_digit`` is
            replaced with the ``NUM`` placeholder.
        lowercase (bool): When True, the result is lower-cased; this is
            applied after the number substitution.

    Returns:
        str: The normalized word.
    """
    replace_numbers = not keep_number
    if replace_numbers and is_digit(word):
        word = NUM
    return word.lower() if lowercase else word
|
||||
|
||||
|
||||
def is_digit(word):
    """Return True when ``word`` looks like a number.

    A word counts as a number when any of the following holds:

    * ``float(word)`` succeeds (integers, decimals, exponent notation),
    * it is a single character with a Unicode numeric value,
    * it matches ``<digits>,<digits>`` with an optional sign.

    The literals ``inf``, ``infinity`` and ``nan`` (any case, optional sign)
    are accepted by ``float`` but are ordinary words in running text, so they
    are explicitly not treated as numbers.

    Args:
        word (str): Token to test.

    Returns:
        bool: True if the token is numeric, False otherwise.
    """
    if isinstance(word, str):
        bare = word.strip().lstrip("+-").lower()
        if bare in ("inf", "infinity", "nan"):
            return False

    try:
        float(word)
        return True
    except ValueError:
        pass

    try:
        # Accepts single numeric characters; longer strings raise TypeError.
        unicodedata.numeric(word)
        return True
    except (TypeError, ValueError):
        pass

    # Numbers written with one comma separator, e.g. "1,000" or "-3,5".
    return re.match(r'^[-+]?[0-9]+,[0-9]+$', word) is not None
|
@ -1,74 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.initializer as I
|
||||
from paddlenlp.transformers import BertForTokenClassification
|
||||
|
||||
|
||||
class BertBLSTMPunc(nn.Layer):
    """BERT token encoder followed by a bidirectional LSTM and a linear head."""

    def __init__(self,
                 pretrained_token="bert-large-uncased",
                 output_size=4,
                 dropout=0.0,
                 bert_size=768,
                 blstm_size=128,
                 num_blstm_layers=2,
                 init_scale=0.1):
        """Build the layers.

        Args:
            pretrained_token (str): Pretrained BERT model name.
            output_size (int): Number of output classes.
            dropout (float): Dropout probability applied before the head.
            bert_size (int): Width of the per-token BERT output fed to the
                LSTM.
            blstm_size (int): Hidden size of each LSTM direction.
            num_blstm_layers (int): Number of stacked LSTM layers.
            init_scale (float): LSTM weights are drawn uniformly from
                ``[-init_scale, init_scale]``.
        """
        super().__init__()
        self.output_size = output_size

        # The token-classification head is sized to bert_size, so its output
        # serves as a bert_size-wide feature vector per token.
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_token, num_classes=bert_size)

        def uniform_attr():
            # A fresh attribute object for every weight that needs one.
            return paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale))

        self.lstm = nn.LSTM(
            input_size=bert_size,
            hidden_size=blstm_size,
            num_layers=num_blstm_layers,
            direction="bidirect",
            weight_ih_attr=uniform_attr(),
            weight_hh_attr=uniform_attr())

        self.dropout = nn.Dropout(dropout)
        # Both LSTM directions are concatenated, hence blstm_size * 2 inputs.
        self.fc = nn.Linear(blstm_size * 2, output_size)
        self.softmax = nn.Softmax()

    def forward(self, x):
        """Classify every token of ``x``.

        Returns:
            tuple: ``(scores, probs)`` where ``scores`` has shape
            ``[-1, output_size]`` and ``probs`` is the softmax of ``scores``.
        """
        features = self.bert(x)
        lstm_out, (_, _) = self.lstm(features)

        scores = self.fc(self.dropout(lstm_out))
        scores = paddle.reshape(scores, shape=[-1, self.output_size])

        probs = self.softmax(scores)
        return scores, probs
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the model with its default arguments and run a single
    # forward pass on a [2, 5] batch of random integer ids drawn from [0, 40).
    print('start model')
    model = BertBLSTMPunc()
    x = paddle.randint(low=0, high=40, shape=[2, 5])
    print(x)
    y, logit = model(x)
|
@ -1,63 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddlenlp.transformers import BertForTokenClassification
|
||||
|
||||
|
||||
class BertLinearPunc(nn.Layer):
    """BERT token encoder followed by a two-layer feed-forward head."""

    def __init__(self,
                 pretrained_token="bert-base-uncased",
                 output_size=4,
                 dropout=0.2,
                 bert_size=768,
                 hiddensize=1568):
        """Build the layers.

        Args:
            pretrained_token (str): Pretrained BERT model name.
            output_size (int): Number of output classes.
            dropout (float): Dropout probability used before each linear
                layer.
            bert_size (int): Width of the per-token BERT output.
            hiddensize (int): Width of the hidden linear layer.
        """
        super().__init__()
        self.output_size = output_size

        # The token-classification head is sized to bert_size, so its output
        # serves as a bert_size-wide feature vector per token.
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_token, num_classes=bert_size)

        self.dropout1 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(bert_size, hiddensize)
        self.dropout2 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hiddensize, output_size)
        self.softmax = nn.Softmax()

    def forward(self, x):
        """Classify every token of ``x``.

        Returns:
            tuple: ``(scores, probs)`` where ``scores`` has shape
            ``[-1, output_size]`` and ``probs`` is the softmax of ``scores``.
        """
        features = self.bert(x)

        hidden = self.fc1(self.dropout1(features))
        activated = self.relu(self.dropout2(hidden))
        scores = self.fc2(activated)
        scores = paddle.reshape(scores, shape=[-1, self.output_size])

        probs = self.softmax(scores)
        return scores, probs
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the model with its default arguments and run a single
    # forward pass on a [2, 5] batch of random integer ids drawn from [0, 40).
    print('start model')
    model = BertLinearPunc()
    x = paddle.randint(low=0, high=40, shape=[2, 5])
    print(x)
    y, logit = model(x)
|
@ -1,85 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.initializer as I
|
||||
|
||||
|
||||
class RnnLm(nn.Layer):
    """Embedding + LSTM + linear tagger that predicts a label per token."""

    def __init__(self,
                 vocab_size,
                 punc_size,
                 hidden_size,
                 num_layers=1,
                 init_scale=0.1,
                 dropout=0.0):
        """Build the layers.

        Args:
            vocab_size (int): Number of rows in the embedding table.
            punc_size (int): Number of output classes.
            hidden_size (int): Embedding width and LSTM hidden size.
            num_layers (int): Number of stacked LSTM layers.
            init_scale (float): Parameters are drawn uniformly from
                ``[-init_scale, init_scale]``.
            dropout (float): Dropout probability (inside the LSTM and around
                it).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.punc_size = punc_size

        def uniform_attr():
            # A fresh attribute object for every parameter that needs one.
            return paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale))

        self.embedder = nn.Embedding(
            vocab_size, hidden_size, weight_attr=uniform_attr())

        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            weight_ih_attr=uniform_attr(),
            weight_hh_attr=uniform_attr())

        self.fc = nn.Linear(
            hidden_size,
            punc_size,
            weight_attr=uniform_attr(),
            bias_attr=uniform_attr())

        self.dropout = nn.Dropout(p=dropout)
        self.softmax = nn.Softmax()

    def forward(self, inputs):
        """Tag every token id in ``inputs``.

        Returns:
            tuple: ``(scores, probs)`` where ``scores`` has shape
            ``[-1, punc_size]`` and ``probs`` is the softmax of ``scores``.
        """
        embedded = self.dropout(self.embedder(inputs))

        lstm_out, (_, _) = self.lstm(embedded)

        scores = self.fc(self.dropout(lstm_out))
        scores = paddle.reshape(scores, shape=[-1, self.punc_size])
        probs = self.softmax(scores)
        return scores, probs
|
||||
|
||||
|
||||
class CrossEntropyLossForLm(nn.Layer):
    """Cross-entropy averaged over axis 0 and summed over what remains."""

    def __init__(self):
        super().__init__()

    def forward(self, y, label):
        """Compute the loss.

        Args:
            y: Prediction scores passed to ``cross_entropy`` as ``input``.
            label: Integer labels; a trailing axis is added at position 2
                before the loss is computed.

        Returns:
            The scalar loss tensor.
        """
        target = paddle.unsqueeze(label, axis=2)
        per_token = paddle.nn.functional.cross_entropy(
            input=y, label=target, reduction='none')
        # Remove the trailing axis again, then mean over axis 0, sum the rest.
        per_token = paddle.squeeze(per_token, axis=[2])
        per_step = paddle.mean(per_token, axis=[0])
        return paddle.sum(per_step)
|
@ -1,141 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import OrderedDict
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
|
||||
|
||||
|
||||
def brelu(x, t_min=0.0, t_max=24.0, name=None):
    """Bounded ReLU: raise ``x`` to at least ``t_min``, then cap at ``t_max``.

    Args:
        x (paddle.Tensor): Input tensor.
        t_min (float): Lower bound. Defaults to 0.0.
        t_max (float): Upper bound. Defaults to 24.0.
        name: Unused.

    Returns:
        paddle.Tensor: Element-wise ``min(max(x, t_min), t_max)``.
    """
    # paddle.to_tensor is dygraph_only can not work under JIT
    lower = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
    upper = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
    floored = x.maximum(lower)
    return floored.minimum(upper)
|
||||
|
||||
|
||||
class LinearGLUBlock(nn.Layer):
    """A linear Gated Linear Units (GLU) block."""

    def __init__(self, idim: int):
        """ GLU.

        Args:
            idim (int): input and output dimension
        """
        super().__init__()
        # Project to twice the width; GLU halves it again in forward().
        self.fc = nn.Linear(idim, idim * 2)

    def forward(self, xs):
        """Apply the linear projection followed by GLU on the last axis.

        Args:
            xs (paddle.Tensor): Input whose last dimension is ``idim``.

        Returns:
            paddle.Tensor: Output whose last dimension is ``idim``.
        """
        # Use paddle's own GLU (keyword `axis`); no bare `glu` name is
        # defined or imported in this module.
        return paddle.nn.functional.glu(self.fc(xs), axis=-1)
|
||||
|
||||
|
||||
class ConvGLUBlock(nn.Layer):
    # NOTE(review): `ConstantPad2d` and `GLU` are used below as bare names but
    # are not among the imports visible at the top of this file
    # (OrderedDict, paddle, nn). Confirm where they are expected to come from;
    # as written, constructing this block would raise NameError.
    def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
                 dropout=0.):
        """A convolutional Gated Linear Units (GLU) block.

        Args:
            kernel_size (int): kernel size
            in_ch (int): number of input channels
            out_ch (int): number of output channels
            bottlececk_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
            dropout (float): dropout probability. Defaults to 0..
        """

        super().__init__()

        # A 1x1 convolution on the residual path is only created when the
        # channel counts differ.
        self.conv_residual = None
        if in_ch != out_ch:
            self.conv_residual = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
                name='weight',
                dim=0)
            self.dropout_residual = nn.Dropout(p=dropout)

        # Pads kernel_size - 1 positions on one side of one axis only
        # (presumably the start of the time axis - TODO confirm the padding
        # order of the intended ConstantPad2d implementation).
        self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)

        layers = OrderedDict()
        if bottlececk_dim == 0:
            # Single convolution producing out_ch * 2 channels, then GLU.
            layers['conv'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch,
                    out_channels=out_ch * 2,
                    kernel_size=(kernel_size, 1)),
                name='weight',
                dim=0)
            # TODO(hirofumi0810): padding?
            layers['dropout'] = nn.Dropout(p=dropout)
            layers['glu'] = GLU()

        elif bottlececk_dim > 0:
            # 1x1 conv in -> (kernel_size, 1) conv -> GLU -> 1x1 conv out.
            layers['conv_in'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch,
                    out_channels=bottlececk_dim,
                    kernel_size=(1, 1)),
                name='weight',
                dim=0)
            layers['dropout_in'] = nn.Dropout(p=dropout)
            layers['conv_bottleneck'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=bottlececk_dim,
                    out_channels=bottlececk_dim,
                    kernel_size=(kernel_size, 1)),
                name='weight',
                dim=0)
            layers['dropout'] = nn.Dropout(p=dropout)
            layers['glu'] = GLU()
            layers['conv_out'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=bottlececk_dim,
                    out_channels=out_ch * 2,
                    kernel_size=(1, 1)),
                name='weight',
                dim=0)
            layers['dropout_out'] = nn.Dropout(p=dropout)

        # NOTE(review): a negative bottlececk_dim leaves `layers` empty.
        self.layers = nn.Sequential(layers)

    def forward(self, xs):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
        Returns:
            out (FloatTensor): `[B, out_ch, T, feat_dim]`
        """
        residual = xs
        if self.conv_residual is not None:
            residual = self.dropout_residual(self.conv_residual(residual))
        xs = self.pad_left(xs)  # `[B, embed_dim, T+kernel-1, 1]`
        xs = self.layers(xs)  # `[B, out_ch * 2, T ,1]`
        # NOTE(review): the addition needs xs and residual to have matching
        # (or broadcastable) shapes - verify which axis GLU halves.
        xs = xs + residual
        return xs
|
||||
|
||||
|
||||
def get_activation(act):
    """Return an activation callable selected by name.

    Args:
        act (str): One of ``"hardtanh"``, ``"tanh"``, ``"relu"``, ``"selu"``,
            ``"swish"``, ``"gelu"`` or ``"brelu"``.

    Returns:
        A callable that applies the activation to a tensor: a new layer
        instance for the ``paddle.nn`` entries, or the ``brelu`` function
        itself.

    Raises:
        KeyError: If ``act`` is not one of the supported names.
    """
    # Lazy load to avoid unused import
    activation_funcs = {
        "hardtanh": paddle.nn.Hardtanh,
        "tanh": paddle.nn.Tanh,
        "relu": paddle.nn.ReLU,
        "selu": paddle.nn.SELU,
        "swish": paddle.nn.Swish,
        "gelu": paddle.nn.GELU,
        "brelu": brelu,
    }

    activation = activation_funcs[act]
    # Layer classes are instantiated. brelu is a plain function that needs
    # its input tensor, so it is returned as-is instead of being called with
    # no arguments (which would raise TypeError).
    if isinstance(activation, type):
        return activation()
    return activation
|
@ -1,229 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Multi-Head Attention layer definition."""
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
|
||||
|
||||
# Relative Positional Encodings
|
||||
# https://www.jianshu.com/p/c0608efcc26f
|
||||
# https://zhuanlan.zhihu.com/p/344604604
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Layer):
    """Multi-Head Attention layer."""

    # NOTE(review): the methods below call torch-style tensor methods
    # (`.size(0)`, `.view(...)`, `.eq(0)`, `.masked_fill(...)`,
    # `.contiguous()`). Presumably these are provided on paddle.Tensor by
    # project-level patches - confirm they are installed wherever this
    # module is used.

    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
        """Construct an MultiHeadedAttention object.
        Args:
            n_head (int): The number of heads.
            n_feat (int): The number of features.
            dropout_rate (float): Dropout rate.
        """
        super().__init__()
        # The feature dimension is split evenly across the heads.
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self,
                    query: paddle.Tensor,
                    key: paddle.Tensor,
                    value: paddle.Tensor
                    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Transform query, key and value.
        Args:
            query (paddle.Tensor): Query tensor (#batch, time1, size).
            key (paddle.Tensor): Key tensor (#batch, time2, size).
            value (paddle.Tensor): Value tensor (#batch, time2, size).
        Returns:
            paddle.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            paddle.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            paddle.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        # Project, then split the feature axis into (head, d_k).
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
        k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
        v = v.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self,
                          value: paddle.Tensor,
                          scores: paddle.Tensor,
                          mask: Optional[paddle.Tensor]) -> paddle.Tensor:
        """Compute attention context vector.
        Args:
            value (paddle.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (paddle.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2).
        Returns:
            paddle.Tensor: Transformed value weighted
                by the attention score, (#batch, time1, d_model).
        """
        n_batch = value.size(0)
        if mask is not None:
            # After eq(0) the mask is True where the incoming mask was 0.
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # Those positions get -inf before softmax and are set to 0.0
            # again afterwards.
            scores = scores.masked_fill(mask, -float('inf'))
            attn = paddle.softmax(
                scores, axis=-1).masked_fill(mask,
                                             0.0)  # (batch, head, time1, time2)
        else:
            attn = paddle.softmax(
                scores, axis=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
        # Merge the head axis back into the feature axis.
        x = x.transpose([0, 2, 1, 3]).contiguous().view(
            n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self,
                query: paddle.Tensor,
                key: paddle.Tensor,
                value: paddle.Tensor,
                mask: Optional[paddle.Tensor]) -> paddle.Tensor:
        """Compute scaled dot product attention.
        Args:
            query (paddle.Tensor): Query tensor (#batch, time1, size).
            key (paddle.Tensor): Key tensor (#batch, time2, size).
            value (paddle.Tensor): Value tensor (#batch, time2, size).
            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
        Returns:
            paddle.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Dot products of queries and keys, scaled by sqrt(d_k).
        scores = paddle.matmul(q,
                               k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding."""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an RelPositionMultiHeadedAttention object.
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
"""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
#self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
#self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
#torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
#torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
pos_bias_u = self.create_parameter(
|
||||
[self.h, self.d_k], default_initializer=I.XavierUniform())
|
||||
self.add_parameter('pos_bias_u', pos_bias_u)
|
||||
pos_bias_v = self.create_parameter(
|
||||
(self.h, self.d_k), default_initializer=I.XavierUniform())
|
||||
self.add_parameter('pos_bias_v', pos_bias_v)
|
||||
|
||||
def rel_shift(self, x, zero_triu: bool=False):
|
||||
"""Compute relative positinal encoding.
|
||||
Args:
|
||||
x (paddle.Tensor): Input tensor (batch, head, time1, time1).
|
||||
zero_triu (bool): If true, return the lower triangular part of
|
||||
the matrix.
|
||||
Returns:
|
||||
paddle.Tensor: Output tensor. (batch, head, time1, time1)
|
||||
"""
|
||||
zero_pad = paddle.zeros(
|
||||
(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype)
|
||||
x_padded = paddle.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
|
||||
|
||||
if zero_triu:
|
||||
ones = paddle.ones((x.size(2), x.size(3)))
|
||||
x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self,
            query: paddle.Tensor,
            key: paddle.Tensor,
            value: paddle.Tensor,
            pos_emb: paddle.Tensor,
            mask: Optional[paddle.Tensor]):
    """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

    The attention score is the sum of a content term (query plus
    ``pos_bias_u`` against the keys) and a position term (query plus
    ``pos_bias_v`` against the projected positional embedding), scaled by
    ``sqrt(d_k)``; see https://arxiv.org/abs/1901.02860 Section 3.3.

    Args:
        query (paddle.Tensor): Query tensor (#batch, time1, size).
        key (paddle.Tensor): Key tensor (#batch, time2, size).
        value (paddle.Tensor): Value tensor (#batch, time2, size).
        pos_emb (paddle.Tensor): Positional embedding tensor
            (#batch, time1, size).
        mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2).
    Returns:
        paddle.Tensor: Output tensor (#batch, time1, d_model).
    """
    swap_time_head = [0, 2, 1, 3]
    swap_last_two = [0, 1, 3, 2]

    q, k, v = self.forward_qkv(query, key, value)
    q = q.transpose(swap_time_head)  # (batch, time1, head, d_k)

    # project the positional embedding and split it into heads
    pos = self.linear_pos(pos_emb).view(
        pos_emb.size(0), -1, self.h, self.d_k)
    pos = pos.transpose(swap_time_head)  # (batch, head, time1, d_k)

    # add the learnable biases, then go back to (batch, head, time1, d_k)
    q_bias_u = (q + self.pos_bias_u).transpose(swap_time_head)
    q_bias_v = (q + self.pos_bias_v).transpose(swap_time_head)

    # matrix a + matrix c of the paper: (batch, head, time1, time2)
    content_scores = paddle.matmul(q_bias_u, k.transpose(swap_last_two))

    # matrix b + matrix d of the paper: (batch, head, time1, time2)
    # rel_shift is deliberately not applied here: it is useless in speech
    # recognition and it requires special attention for streaming.
    position_scores = paddle.matmul(q_bias_v, pos.transpose(swap_last_two))

    # (batch, head, time1, time2)
    scores = (content_scores + position_scores) / math.sqrt(self.d_k)

    return self.forward_attention(v, scores, mask)
|
@ -1,366 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
__all__ = ['CRF']
|
||||
|
||||
|
||||
class CRF(nn.Layer):
    """Linear-chain Conditional Random Field (CRF).

    Args:
        nb_labels (int): number of labels in your tagset, including special symbols.
        bos_tag_id (int): integer representing the beginning of sentence symbol in
            your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
        pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
            If None, the model will treat the PAD as a normal tag. Otherwise, the model
            will apply constraints for PAD transitions.
        batch_first (bool): Whether the first dimension represents the batch dimension.
    """

    def __init__(self,
                 nb_labels: int,
                 bos_tag_id: int,
                 eos_tag_id: int,
                 pad_tag_id: int=None,
                 batch_first: bool=True):
        super().__init__()

        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        self.PAD_TAG_ID = pad_tag_id
        self.batch_first = batch_first

        # transitions[i, j] scores the move from tag i to tag j; initialized
        # from a random uniform distribution between -0.1 and 0.1
        self.transitions = self.create_parameter(
            [self.nb_labels, self.nb_labels],
            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
        self.init_weights()

    def init_weights(self):
        """Write the hard transition constraints into ``self.transitions``."""
        # enforce constraints (rows=from, columns=to) with a big negative number
        # so exp(-10000) will tend to zero

        # no transitions allowed to the beginning of sentence
        self.transitions[:, self.BOS_TAG_ID] = -10000.0
        # no transition allowed from the end of sentence
        self.transitions[self.EOS_TAG_ID, :] = -10000.0

        if self.PAD_TAG_ID is not None:
            # no transitions from padding
            self.transitions[self.PAD_TAG_ID, :] = -10000.0
            # no transitions to padding
            self.transitions[:, self.PAD_TAG_ID] = -10000.0
            # except if the end of sentence is reached
            # or we are already in a pad position
            self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
            self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0

    def forward(self,
                emissions: paddle.Tensor,
                tags: paddle.Tensor,
                mask: paddle.Tensor=None) -> paddle.Tensor:
        """Compute the negative log-likelihood. See `log_likelihood` method."""
        nll = -self.log_likelihood(emissions, tags, mask=mask)
        return nll

    @staticmethod
    def _swap_leading_axes(tensor):
        """Return `tensor` with its first two axes exchanged."""
        perm = list(range(len(tensor.shape)))
        perm[0], perm[1] = perm[1], perm[0]
        # Paddle's transpose takes the full permutation, not a pair of axes
        return paddle.transpose(tensor, perm)

    def log_likelihood(self, emissions, tags, mask=None):
        """Compute the probability of a sequence of tags given a sequence of
        emissions scores.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            tags (paddle.LongTensor): Sequence of labels.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: sum of the log-likelihoods for each sequence in the batch.
                Shape of ()
        """
        # fix tensors order by setting batch as the first dimension
        if not self.batch_first:
            emissions = self._swap_leading_axes(emissions)
            tags = self._swap_leading_axes(tags)
            if mask is not None:
                # a caller-provided mask follows the same layout as the tags
                mask = self._swap_leading_axes(mask)

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)

        scores = self._compute_scores(emissions, tags, mask=mask)
        partition = self._compute_log_partition(emissions, mask=mask)
        return paddle.sum(scores - partition)

    def decode(self, emissions, mask=None):
        """Find the most probable sequence of labels given the emissions using
        the Viterbi algorithm.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: the viterbi score for each batch.
                Shape of (batch_size,)
            list of lists: the best viterbi sequence of labels for each batch. [B, T]
        """
        # fix tensors order by setting batch as the first dimension
        # (decoding has no `tags` input, so only emissions and mask move)
        if not self.batch_first:
            emissions = self._swap_leading_axes(emissions)
            if mask is not None:
                mask = self._swap_leading_axes(mask)

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)

        scores, sequences = self._viterbi_decode(emissions, mask)
        return scores, sequences

    def _compute_scores(self, emissions, tags, mask):
        """Compute the scores for a given batch of emissions with their tags.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            tags (Paddle.LongTensor): (batch_size, seq_len)
            mask (Paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: Scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length = tags.shape
        scores = paddle.zeros([batch_size])

        # save first and last tags to be used later
        first_tags = tags[:, 0]
        last_valid_idx = mask.int().sum(1) - 1

        # TODO(Hui Zhang): not support fancy index.
        # (row, column) index pairs emulate tags[b, last_valid_idx[b]]
        batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
        gather_last_valid_idx = paddle.stack(
            [batch_idx, last_valid_idx], axis=-1)
        last_tags = tags.gather_nd(gather_last_valid_idx)

        # add the transition from BOS to the first tags for each batch
        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)

        # add the [unary] emission scores for the first tags for each batch:
        # for every sample, pick the emission of its own first tag
        gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
        e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)

        # the scores for a word is just the sum of both scores
        scores += e_scores + t_scores

        # now lets do this for each remaining word
        for i in range(1, seq_length):

            # we could: iterate over batches, check if we reached a mask symbol
            # and stop the iteration, but vectorizing is faster due to gpu,
            # so instead we perform an element-wise multiplication
            is_valid = mask[:, i]

            previous_tags = tags[:, i - 1]
            current_tags = tags[:, i]

            # calculate emission and transition scores as we did before
            gather_current_tags_idx = paddle.stack(
                [batch_idx, current_tags], axis=-1)
            e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
            # emulates transitions[previous_tags, current_tags]
            gather_transitions_idx = paddle.stack(
                [previous_tags, current_tags], axis=-1)
            t_scores = self.transitions.gather_nd(gather_transitions_idx)

            # apply the mask
            e_scores = e_scores * is_valid
            t_scores = t_scores * is_valid

            scores += e_scores + t_scores

        # add the transition from the end tag to the EOS tag for each batch
        scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]

        return scores

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the forward-algorithm.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (Paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the partition scores for each batch.
                Shape of (batch_size,)
        """
        seq_length = emissions.shape[1]

        # in the first iteration, BOS will have all the scores
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (1, nb_labels, nb_labels), broadcast over bs
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            scores = e_scores + t_scores + a_scores
            new_alphas = paddle.logsumexp(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # return a *log* of sums of exps
        return paddle.logsumexp(end_scores, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Compute the viterbi algorithm to find the most probable sequence of labels
        given a sequence of emissions.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (Paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the viterbi score for each batch.
                Shape of (batch_size,)
            list of lists of ints: the best viterbi sequence of labels for each batch
        """
        batch_size, seq_length, _ = emissions.shape

        # in the first iteration, BOS will have all the scores and then, the max
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        backpointers = []

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (1, nb_labels, nb_labels), broadcast over bs
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + a_scores

            # so far is exactly like the forward algorithm,
            # but now, instead of calculating the logsumexp,
            # we will find the highest score and the tag associated with it
            max_scores = paddle.max(scores, axis=1)
            max_score_tags = paddle.argmax(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * max_scores + (1 - is_valid) * alphas

            # add the max_score_tags for our list of backpointers
            # max_score_tags has shape (batch_size, nb_labels) so we transpose it
            # to (nb_labels, batch_size), the layout `_find_best_path` indexes
            backpointers.append(max_score_tags.t())

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # get the final most probable score and the final most probable tag
        max_final_scores = paddle.max(end_scores, axis=1)
        max_final_tags = paddle.argmax(end_scores, axis=1)

        # find the best sequence of labels for each sample in the batch
        best_sequences = []
        emission_lengths = mask.int().sum(axis=1)
        for i in range(batch_size):

            # recover the original sentence length for the i-th sample in the batch
            sample_length = emission_lengths[i].item()

            # recover the max tag for the last timestep
            sample_final_tag = max_final_tags[i].item()

            # limit the backpointers until the last but one
            # since the last corresponds to the sample_final_tag
            sample_backpointers = backpointers[:sample_length - 1]

            # follow the backpointers to build the sequence of labels
            sample_path = self._find_best_path(i, sample_final_tag,
                                               sample_backpointers)

            # add this path to the list of best sequences
            best_sequences.append(sample_path)

        return max_final_scores, best_sequences

    def _find_best_path(self, sample_id, best_tag, backpointers):
        """Auxiliary function to find the best path sequence for a specific sample.

        Args:
            sample_id (int): sample index in the range [0, batch_size)
            best_tag (int): tag which maximizes the final score
            backpointers (list of tensors): one (nb_labels, batch_size)
                tensor of pointers per step, seq_len_i - 1 entries where
                seq_len_i represents the length of the ith sample in the batch

        Returns:
            list of ints: a list of tag indexes representing the best path
        """
        # collect the tags from last to first, starting with the final best_tag
        reversed_path = [best_tag]

        # traverse the backpointers backwards
        for backpointers_t in reversed(backpointers):

            # recover the best_tag at this timestep
            best_tag = backpointers_t[best_tag][sample_id].item()
            reversed_path.append(best_tag)

        # a single reversal avoids the quadratic cost of inserting at the front
        reversed_path.reverse()
        return reversed_path
|
@ -1,98 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class FocalLossHX(nn.Layer):
    """Focal loss computed as ``(1 - pt) ** gamma * cross_entropy``.

    ``pt`` is the softmax probability of the target class; it is obtained
    through numpy, so it acts as a constant weight on the cross entropy.

    Args:
        gamma (float): exponent applied to ``1 - pt``. Default ``0``.
        size_average (bool): if True return the mean of the per-sample
            losses, otherwise their sum. Default ``True``.
    """

    def __init__(self, gamma=0, size_average=True):
        # zero-argument super(): naming FocalLoss here raised TypeError,
        # because FocalLossHX is not a subclass of FocalLoss
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, input, target):
        """Return the reduced focal loss of `input` logits against `target`.

        Args:
            input (paddle.Tensor): logits; inputs with more than two
                dimensions (N, C, ...) are flattened to (N*..., C).
            target (paddle.Tensor): class indices, flattened to 1-D.

        Returns:
            paddle.Tensor: mean of the per-sample losses when
                ``size_average`` is True, otherwise their sum.
        """
        if input.dim() > 2:
            input = paddle.reshape(
                input,
                shape=[input.shape[0], input.shape[1], -1])  # N,C,H,W => N,C,H*W
            # Paddle's transpose takes the full permutation
            input = paddle.transpose(input, [0, 2, 1])  # N,C,H*W => N,H*W,C
            input = paddle.reshape(
                input, shape=[-1, input.shape[2]])  # N,H*W,C => N*H*W,C
        target = paddle.reshape(target, shape=[-1])

        logpt = F.log_softmax(input)

        # get true class column from each row
        all_rows = paddle.arange(len(input))
        log_pt = logpt.numpy()[all_rows.numpy(), target.numpy()]

        pt = paddle.to_tensor(log_pt, dtype='float64').exp()
        ce = F.cross_entropy(input, target, reduction='none')

        loss = (1 - pt)**self.gamma * ce
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
|
||||
|
||||
|
||||
class FocalLoss(nn.Layer):
    """Focal Loss.

    Code referenced from:
    https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

    Args:
        gamma (float): the coefficient of Focal Loss. Default ``2.0``.
    """

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, label):
        """Return the mean focal loss of `logit` against `label`.

        Args:
            logit (paddle.Tensor): class scores, one row per sample; no
                reshaping is applied to it here.
            label (paddle.Tensor): target class indices; flattened and
                cast to int64.

        Returns:
            paddle.Tensor: mean of ``-(1 - pt) ** gamma * log(pt)``.
        """
        # build (row, class) index pairs selecting each sample's target entry
        class_ids = paddle.cast(paddle.reshape(label, [-1, 1]), dtype='int64')
        row_ids = paddle.unsqueeze(
            paddle.arange(0, class_ids.shape[0]), axis=-1)
        pick_idx = paddle.concat([row_ids, class_ids], axis=-1)

        # log-probability of the target class of every sample
        target_logpt = paddle.gather_nd(F.log_softmax(logit), pick_idx)

        # pt is detached, so the modulating factor carries no gradient
        pt = paddle.exp(target_logpt.detach())
        focal_weight = (1 - pt)**self.gamma
        per_sample = -1 * focal_weight * target_logpt
        return paddle.mean(per_sample)
|
@ -1,304 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Text
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.optimizer import Optimizer
|
||||
from speechtask.punctuation_restoration.utils import mp_tools
|
||||
# from speechtask.punctuation_restoration.utils.log import Log
|
||||
|
||||
# logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ["Checkpoint"]
|
||||
|
||||
|
||||
class Checkpoint():
    """Save and load checkpoints, keeping the k best and the n latest ones.

    A checkpoint is a group of files sharing the prefix
    ``<checkpoint_dir>/<tag_or_iteration>``: ``.pdparams`` (model),
    ``.pdopt`` (optimizer, optional) and ``.json`` (infos).

    Args:
        logger: object whose ``info`` method is used for progress messages.
        kbest_n (int): number of best checkpoints to keep; -1 keeps all.
        latest_n (int): number of latest checkpoints to keep.
        metric_type (str): key in ``infos`` holding the metric used to rank
            checkpoints; a lower value is treated as better.
    """

    def __init__(self,
                 logger,
                 kbest_n: int=5,
                 latest_n: int=1,
                 metric_type='val_loss'):
        # maps tag_or_iteration -> metric value of the kept best checkpoints
        self.best_records = {}
        # tags of the kept latest checkpoints, oldest first
        self.latest_records = []
        self.kbest_n = kbest_n
        self.latest_n = latest_n
        self._save_all = (kbest_n == -1)
        self.logger = logger
        self.metric_type = metric_type

    def add_checkpoint(self,
                       checkpoint_dir,
                       tag_or_iteration: Union[int, Text],
                       model: paddle.nn.Layer,
                       optimizer: Optimizer=None,
                       infos: dict=None):
        """Save checkpoint in best_n and latest_n.

        When ``infos`` is None or lacks the configured metric, the
        checkpoint is saved without touching the best/latest bookkeeping.

        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
            model (Layer): model to be checkpointed.
            optimizer (Optimizer, optional): optimizer to be checkpointed.
            infos (dict or None)): any info you want to save.
        """
        metric_type = self.metric_type
        # `infos` defaults to None, so guard before looking the metric up
        if infos is None or metric_type not in infos:
            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                                  optimizer, infos)
            return

        #save best
        if self._should_save_best(infos[metric_type]):
            self._save_best_checkpoint_and_update(
                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
                optimizer, infos)
        #save latest
        self._save_latest_checkpoint_and_update(
            checkpoint_dir, tag_or_iteration, model, optimizer, infos)

        # the record files are only refreshed for integer iterations
        if isinstance(tag_or_iteration, int):
            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)

    def load_parameters(self,
                        model,
                        optimizer=None,
                        checkpoint_dir=None,
                        checkpoint_path=None,
                        record_file="checkpoint_latest"):
        """Load a last model checkpoint from disk.

        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
            record_file (str): "checkpoint_latest" or "checkpoint_best"; the
                last entry of this file inside checkpoint_dir selects the
                checkpoint to load.

        Returns:
            configs (dict): epoch or step, lr and other meta info should be saved.
                Empty when no checkpoint record is found.

        Raises:
            ValueError: if neither 'checkpoint_path' nor 'checkpoint_dir'
                is usable.
        """
        configs = {}

        if checkpoint_path is not None:
            pass
        elif checkpoint_dir is not None and record_file is not None:
            # load checkpoint from record file
            checkpoint_record = os.path.join(checkpoint_dir, record_file)
            iteration = self._load_checkpoint_idx(checkpoint_record)
            if iteration == -1:
                return configs
            checkpoint_path = os.path.join(checkpoint_dir,
                                           "{}".format(iteration))
        else:
            raise ValueError(
                "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
            )

        rank = dist.get_rank()

        params_path = checkpoint_path + ".pdparams"
        model_dict = paddle.load(params_path)
        model.set_state_dict(model_dict)
        self.logger.info(
            "Rank {}: loaded model from {}".format(rank, params_path))

        optimizer_path = checkpoint_path + ".pdopt"
        if optimizer and os.path.isfile(optimizer_path):
            optimizer_dict = paddle.load(optimizer_path)
            optimizer.set_state_dict(optimizer_dict)
            self.logger.info("Rank {}: loaded optimizer state from {}".format(
                rank, optimizer_path))

        # the infos file shares the checkpoint prefix
        info_path = checkpoint_path + ".json"
        if os.path.exists(info_path):
            with open(info_path, 'r') as fin:
                configs = json.load(fin)
        return configs

    def load_latest_parameters(self,
                               model,
                               optimizer=None,
                               checkpoint_dir=None,
                               checkpoint_path=None):
        """Load the checkpoint referenced by the "checkpoint_latest" record.

        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
        Returns:
            configs (dict): epoch or step, lr and other meta info should be saved.
        """
        return self.load_parameters(model, optimizer, checkpoint_dir,
                                    checkpoint_path, "checkpoint_latest")

    def load_best_parameters(self,
                             model,
                             optimizer=None,
                             checkpoint_dir=None,
                             checkpoint_path=None):
        """Load the checkpoint referenced by the "checkpoint_best" record.

        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
        Returns:
            configs (dict): epoch or step, lr and other meta info should be saved.
        """
        return self.load_parameters(model, optimizer, checkpoint_dir,
                                    checkpoint_path, "checkpoint_best")

    def _should_save_best(self, metric: float) -> bool:
        """Tell whether `metric` earns a place among the best checkpoints."""
        if not self._best_full():
            return True

        if not self.best_records:
            # full while empty means kbest_n == 0: no best checkpoint is kept
            return False

        # already full: the candidate must beat the worst (largest) metric
        worst_metric = max(self.best_records.values())
        return metric < worst_metric

    def _best_full(self):
        """Tell whether the best records hold exactly kbest_n entries."""
        return (not self._save_all) and len(self.best_records) == self.kbest_n

    def _latest_full(self):
        """Tell whether the latest records hold exactly latest_n entries."""
        return len(self.latest_records) == self.latest_n

    def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
                                         tag_or_iteration, model, optimizer,
                                         infos):
        """Evict the worst best-record if full, then save and record this one."""
        # remove the worst
        if self._best_full() and self.best_records:
            worst_record_path = max(self.best_records,
                                    key=self.best_records.get)
            self.best_records.pop(worst_record_path)
            # keep the files when the latest records still reference them
            if (worst_record_path not in self.latest_records):
                self.logger.info(
                    "remove the worst checkpoint: {}".format(worst_record_path))
                self._del_checkpoint(checkpoint_dir, worst_record_path)

        # add the new one
        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)
        self.best_records[tag_or_iteration] = metric

    def _save_latest_checkpoint_and_update(
            self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
        """Evict the oldest latest-record if full, then save and record this one."""
        # remove the old
        if self._latest_full():
            to_del_fn = self.latest_records.pop(0)
            # keep the files when the best records still reference them
            if (to_del_fn not in self.best_records.keys()):
                self.logger.info(
                    "remove the latest checkpoint: {}".format(to_del_fn))
                self._del_checkpoint(checkpoint_dir, to_del_fn)
        self.latest_records.append(tag_or_iteration)

        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)

    def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
        """Delete every file named ``<checkpoint_dir>/<tag_or_iteration>.*``."""
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "{}".format(tag_or_iteration))
        for filename in glob.glob(checkpoint_path + ".*"):
            os.remove(filename)
            self.logger.info("delete file: {}".format(filename))

    def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
        """Get the iteration number corresponding to the latest saved checkpoint.

        Args:
            checkpoint_record (str): path of the record file to read.
        Returns:
            int: the latest iteration number. -1 for no checkpoint to load
                (missing record file or no entry in it).
        """
        if not os.path.isfile(checkpoint_record):
            return -1

        # Fetch the latest checkpoint index: the last non-blank line.
        with open(checkpoint_record, "rt") as handle:
            entries = [line.strip() for line in handle if line.strip()]
        if not entries:
            # e.g. a record file written while no checkpoint was tracked
            return -1
        # each entry looks like "model_checkpoint_path:<iteration>"
        return int(entries[-1].split(":")[-1])

    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
        """Rewrite the "checkpoint_best" and "checkpoint_latest" record files.

        Both files are rebuilt from the in-memory records, one
        "model_checkpoint_path:<tag>" line per tracked checkpoint.

        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            iteration (int): the latest iteration number (not used when
                writing; the records already contain it).
        Returns:
            None
        """
        checkpoint_record_latest = os.path.join(checkpoint_dir,
                                                "checkpoint_latest")
        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")

        with open(checkpoint_record_best, "w") as handle:
            for i in self.best_records.keys():
                handle.write("model_checkpoint_path:{}\n".format(i))
        with open(checkpoint_record_latest, "w") as handle:
            for i in self.latest_records:
                handle.write("model_checkpoint_path:{}\n".format(i))

    @mp_tools.rank_zero_only
    def _save_parameters(self,
                         checkpoint_dir: str,
                         tag_or_iteration: Union[int, str],
                         model: paddle.nn.Layer,
                         optimizer: Optimizer=None,
                         infos: dict=None):
        """Checkpoint the latest trained model parameters.

        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            tag_or_iteration (int or str): the latest iteration(step or epoch) number.
            model (Layer): model to be checkpointed.
            optimizer (Optimizer, optional): optimizer to be checkpointed.
                Defaults to None.
            infos (dict or None): any info you want to save.
        Returns:
            None
        """
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "{}".format(tag_or_iteration))

        model_dict = model.state_dict()
        params_path = checkpoint_path + ".pdparams"
        paddle.save(model_dict, params_path)
        self.logger.info("Saved model to {}".format(params_path))

        if optimizer:
            opt_dict = optimizer.state_dict()
            optimizer_path = checkpoint_path + ".pdopt"
            paddle.save(opt_dict, optimizer_path)
            self.logger.info(
                "Saved optimzier state to {}".format(optimizer_path))

        # the infos file shares the checkpoint prefix; None is stored as {}
        info_path = checkpoint_path + ".json"
        infos = {} if infos is None else infos
        with open(info_path, 'w') as fout:
            data = json.dumps(infos)
            fout.write(data)
|
@ -1,88 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
|
||||
__all__ = [
|
||||
"summary", "gradient_norm", "freeze", "unfreeze", "print_grads",
|
||||
"print_params"
|
||||
]
|
||||
|
||||
|
||||
def summary(layer: nn.Layer, print_func=print):
    """Report every entry of ``layer.state_dict()`` followed by the totals.

    One ``name | shape | element count`` line is emitted per entry, then a
    final line with the number of entries and the total element count
    divided by ``1024**2``.

    Args:
        layer (nn.Layer): object whose ``state_dict()`` is summarized.
        print_func (callable or None): sink that receives each line.
            Nothing is reported when it is ``None`` (or otherwise falsy).

    Returns:
        None
    """
    if print_func is None:
        return

    entry_count = 0
    element_count = 0
    for name, param in layer.state_dict().items():
        size = np.prod(param.shape)
        if print_func:
            print_func("{} | {} | {}".format(name, param.shape, size))
        element_count += size
        entry_count += 1

    if not print_func:
        return
    num_elements = element_count / 1024**2
    print_func(
        f"Total parameters: {entry_count}, {num_elements:.2f}M elements.")
|
||||
|
||||
|
||||
def print_grads(model, print_func=print):
    """Emit one line per named parameter with its shape and ``grad``.

    Args:
        model: object providing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.
        print_func (callable or None): sink that receives each line;
            ``None`` disables the output entirely.

    Returns:
        None
    """
    if print_func is None:
        return
    for name, param in model.named_parameters():
        print_func(
            f"param grad: {name}: shape: {param.shape} grad: {param.grad}")
|
||||
|
||||
|
||||
def print_params(model, print_func=print):
    """Emit one line per named parameter followed by the totals.

    Each parameter line reads
    ``name | shape | element count | requires-gradient flag`` where the
    flag is ``not parameter.stop_gradient``. The final line reports the
    number of parameters and the total element count divided by
    ``1024**2``.

    Args:
        model: object providing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.
        print_func (callable or None): sink that receives each line.
            Nothing is reported when it is ``None`` (or otherwise falsy).

    Returns:
        None
    """
    if print_func is None:
        return

    # Integer counters: a float counter would render the parameter count
    # as "2.0" in the summary line.
    total = 0
    num_params = 0
    for name, param in model.named_parameters():
        size = np.prod(param.shape)
        msg = f"{name} | {param.shape} | {size} | {not param.stop_gradient}"
        total += size
        num_params += 1
        if print_func:
            print_func(msg)

    if print_func:
        total = total / 1024**2
        print_func(f"Total parameters: {num_params}, {total:.2f}M elements.")
|
||||
|
||||
|
||||
def gradient_norm(layer: nn.Layer):
    """Collect a size-normalized gradient norm for each trainable entry.

    For every ``(name, param)`` in ``layer.state_dict()`` whose
    ``trainable`` attribute is truthy, the value stored is
    ``np.linalg.norm(grad) / grad.size`` with ``grad = param.gradient()``.

    Args:
        layer (nn.Layer): object whose ``state_dict()`` is inspected.

    Returns:
        dict: entry name -> normalized gradient norm.
    """
    norms = {}
    for name, param in layer.state_dict().items():
        if not param.trainable:
            continue
        # Original note: gradient() returns a numpy.ndarray.
        grad = param.gradient()
        norms[name] = np.linalg.norm(grad) / grad.size
    return norms
|
||||
|
||||
|
||||
def recursively_remove_weight_norm(layer: nn.Layer):
    """Call ``nn.utils.remove_weight_norm`` on every sublayer of ``layer``.

    A ``ValueError`` raised for a sublayer is ignored and the walk
    continues with the next one.

    Args:
        layer (nn.Layer): object whose ``sublayers()`` are processed.

    Returns:
        None
    """
    for sublayer in layer.sublayers():
        try:
            nn.utils.remove_weight_norm(sublayer)
        except ValueError:
            # Per the original note: there is no weight norm hook in
            # this layer, so there is nothing to remove.
            continue
|
||||
|
||||
|
||||
def freeze(layer: nn.Layer):
    """Set ``trainable = False`` on every item of ``layer.parameters()``.

    Args:
        layer (nn.Layer): object whose parameters are updated in place.

    Returns:
        None
    """
    for parameter in layer.parameters():
        parameter.trainable = False
|
||||
|
||||
|
||||
def unfreeze(layer: nn.Layer):
    """Set ``trainable = True`` on every item of ``layer.parameters()``.

    Args:
        layer (nn.Layer): object whose parameters are updated in place.

    Returns:
        None
    """
    for parameter in layer.parameters():
        parameter.trainable = True
|
@ -1,81 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains common utility functions."""
|
||||
import distutils.util
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
__all__ = ['print_arguments', 'add_arguments', "log_add"]
|
||||
|
||||
|
||||
def print_arguments(args, info=None):
    """Print argparse's arguments, one ``name: value`` line each, sorted by name.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    :param info: Optional mapping whose ``"__file__"`` entry names the
        running script; only its base name appears in the banner.
    :type info: dict or None
    """
    script_path = info["__file__"] if info else ""
    filename = os.path.basename(script_path)

    print(f"----------- {filename} Configuration Arguments -----------")
    namespace = vars(args)
    for arg in sorted(namespace):
        print("%s: %s" % (arg, namespace[arg]))
    print("-----------------------------------------------------------")
|
||||
|
||||
|
||||
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Add argparse's argument.

    The option is registered as ``--<argname>`` and `` Default: %(default)s.``
    is appended to its help text. When ``type`` is ``bool`` the value is
    parsed from a truth string instead of being passed to ``bool()``:
    ``y/yes/t/true/on/1`` give ``1`` and ``n/no/f/false/off/0`` give ``0``
    (case-insensitive); anything else is rejected.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()

    :param argname: Option name without the leading ``--``.
    :param type: Callable converting the raw string; ``bool`` is special-cased.
    :param default: Default value of the option.
    :param help: Help text of the option.
    :param argparser: Parser that receives the option.
    :param kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """

    def strtobool(val):
        """Parse a truth string to 1 or 0, else raise ValueError.

        Same contract as ``distutils.util.strtobool``. It is reimplemented
        here because distutils is deprecated (PEP 632) and removed in
        Python 3.12.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val, ))

    type = strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)
|
||||
|
||||
|
||||
def log_add(args: List[int]) -> float:
    """Stable log add.

    Computes ``log(sum(exp(a) for a in args))`` by subtracting the largest
    score before exponentiating and adding it back afterwards.

    Args:
        args (List[int]): log scores

    Returns:
        float: log of the summed exponentials; ``-inf`` when every score
            is ``-inf`` (which includes an empty list).
    """
    neg_inf = -float('inf')
    if all(score == neg_inf for score in args):
        return neg_inf

    peak = max(args)
    shifted_total = sum(math.exp(score - peak) for score in args)
    return peak + math.log(shifted_total)
|
Loading…
Reference in new issue