From b99b0ac42f3243d3f59d18138bff99be726ac582 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Thu, 28 Oct 2021 18:13:56 +0800 Subject: [PATCH] Merge punctuation restoration task into PaddleSpeech. --- text_processing/.gitignore | 7 + text_processing/README.md | 25 + .../punctuation_restoration/chinese/README.md | 35 + .../chinese/conf/blstm.yaml | 34 + .../chinese/conf/data_conf/chinese.yaml | 7 + .../chinese/conf/train_conf/bertBLSTM_zh.yaml | 49 ++ .../conf/train_conf/bertLinear_zh.yaml | 42 ++ .../chinese/local/avg.sh | 23 + .../chinese/local/data.sh | 19 + .../chinese/local/test.sh | 32 + .../chinese/local/train.sh | 32 + .../punctuation_restoration/chinese/path.sh | 13 + .../punctuation_restoration/chinese/run.sh | 47 ++ .../punctuation_restoration/english/README.md | 23 + .../english/conf/data_conf/english.yaml | 7 + .../conf/train_conf/bertBLSTM_base_en.yaml | 47 ++ .../conf/train_conf/bertLinear_en.yaml | 39 ++ .../english/local/avg.sh | 23 + .../english/local/data.sh | 18 + .../english/local/test.sh | 32 + .../english/local/train.sh | 32 + .../punctuation_restoration/english/path.sh | 13 + .../punctuation_restoration/english/run.sh | 47 ++ text_processing/requirements.txt | 6 + .../punctuation_restoration/bin/avg_model.py | 112 +++ .../punctuation_restoration/bin/pre_data.py | 48 ++ .../punctuation_restoration/bin/test.py | 45 ++ .../punctuation_restoration/bin/train.py | 49 ++ .../punctuation_restoration/io/__init__.py | 13 + .../punctuation_restoration/io/collator.py | 64 ++ .../punctuation_restoration/io/common.py | 55 ++ .../punctuation_restoration/io/dataset.py | 310 +++++++++ .../model/BertBLSTM.py | 74 ++ .../model/BertLinear.py | 63 ++ .../punctuation_restoration/model/blstm.py | 89 +++ .../punctuation_restoration/model/lstm.py | 85 +++ .../modules/__init__.py | 13 + .../modules/activation.py | 141 ++++ .../modules/attention.py | 229 ++++++ .../punctuation_restoration/modules/crf.py | 366 ++++++++++ .../training/__init__.py | 13 + .../punctuation_restoration/training/loss.py | 98 +++ .../training/trainer.py | 651 ++++++++++++++++++ .../punctuation_restoration/utils/__init__.py | 13 + .../utils/checkpoint.py | 304 ++++++++ .../utils/default_parser.py | 74 ++ .../utils/layer_tools.py | 88 +++ .../punctuation_restoration/utils/mp_tools.py | 30 + .../utils/punct_pre.py | 163 +++++ .../punctuation_restoration/utils/utility.py | 81 +++ 50 files changed, 3923 insertions(+) create mode 100644 text_processing/.gitignore create mode 100644 text_processing/README.md create mode 100644 text_processing/examples/punctuation_restoration/chinese/README.md create mode 100644 text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml create mode 100644 text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml create mode 100644 text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml create mode 100644 text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml create mode 100644 text_processing/examples/punctuation_restoration/chinese/local/avg.sh create mode 100644 text_processing/examples/punctuation_restoration/chinese/local/data.sh create mode 100644 text_processing/examples/punctuation_restoration/chinese/local/test.sh create mode 100644 text_processing/examples/punctuation_restoration/chinese/local/train.sh create mode 100644 text_processing/examples/punctuation_restoration/chinese/path.sh create mode 100644 
text_processing/examples/punctuation_restoration/chinese/run.sh create mode 100644 text_processing/examples/punctuation_restoration/english/README.md create mode 100644 text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml create mode 100644 text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml create mode 100644 text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml create mode 100644 text_processing/examples/punctuation_restoration/english/local/avg.sh create mode 100644 text_processing/examples/punctuation_restoration/english/local/data.sh create mode 100644 text_processing/examples/punctuation_restoration/english/local/test.sh create mode 100644 text_processing/examples/punctuation_restoration/english/local/train.sh create mode 100644 text_processing/examples/punctuation_restoration/english/path.sh create mode 100644 text_processing/examples/punctuation_restoration/english/run.sh create mode 100644 text_processing/requirements.txt create mode 100644 text_processing/speechtask/punctuation_restoration/bin/avg_model.py create mode 100644 text_processing/speechtask/punctuation_restoration/bin/pre_data.py create mode 100644 text_processing/speechtask/punctuation_restoration/bin/test.py create mode 100644 text_processing/speechtask/punctuation_restoration/bin/train.py create mode 100644 text_processing/speechtask/punctuation_restoration/io/__init__.py create mode 100644 text_processing/speechtask/punctuation_restoration/io/collator.py create mode 100644 text_processing/speechtask/punctuation_restoration/io/common.py create mode 100644 text_processing/speechtask/punctuation_restoration/io/dataset.py create mode 100644 text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py create mode 100644 text_processing/speechtask/punctuation_restoration/model/BertLinear.py create mode 100644 text_processing/speechtask/punctuation_restoration/model/blstm.py create mode 100644 text_processing/speechtask/punctuation_restoration/model/lstm.py create mode 100644 text_processing/speechtask/punctuation_restoration/modules/__init__.py create mode 100644 text_processing/speechtask/punctuation_restoration/modules/activation.py create mode 100644 text_processing/speechtask/punctuation_restoration/modules/attention.py create mode 100644 text_processing/speechtask/punctuation_restoration/modules/crf.py create mode 100644 text_processing/speechtask/punctuation_restoration/training/__init__.py create mode 100644 text_processing/speechtask/punctuation_restoration/training/loss.py create mode 100644 text_processing/speechtask/punctuation_restoration/training/trainer.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/__init__.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/checkpoint.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/default_parser.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/layer_tools.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/mp_tools.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/punct_pre.py create mode 100644 text_processing/speechtask/punctuation_restoration/utils/utility.py diff --git a/text_processing/.gitignore b/text_processing/.gitignore new file mode 100644 index 00000000..e400141b --- /dev/null +++ b/text_processing/.gitignore @@ -0,0 +1,7 @@ +data +glove +.pyc +checkpoints +epoch 
+__pycache__ +glove.840B.300d.zip diff --git a/text_processing/README.md b/text_processing/README.md new file mode 100644 index 00000000..294af01d --- /dev/null +++ b/text_processing/README.md @@ -0,0 +1,25 @@ +# PaddleSpeechTask +A speech library that handles a series of related front-end and back-end tasks + +## Environment +- python==3.6.13 +- paddle==2.1.1 + +## Chinese/English punctuation restoration task: + +### Datasets: data +- Chinese data sources: data/chinese +1. iwslt2012zh +2. the novel Ordinary World (平凡的世界) + +- English data source: data/english +1. iwslt2012en + +- See data/README.md for how to obtain the IWSLT2012 data + +### Models: speechtask/punctuation_restoration/model +1. BLSTM model + +2. BertLinear model + +3. BertBLSTM model diff --git a/text_processing/examples/punctuation_restoration/chinese/README.md b/text_processing/examples/punctuation_restoration/chinese/README.md new file mode 100644 index 00000000..1fcd954c --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/README.md @@ -0,0 +1,35 @@ +# Chinese Experiment Example +## Test data: +- IWSLT2012 Chinese: test2012 + +## Running the code +- Run `run.sh 0 0 conf/train_conf/bertBLSTM_zh.yaml 1 conf/data_conf/chinese.yaml ` + +## Results: +- BertLinear + - Config: conf/train_conf/bertLinear_zh.yaml + - Test results + + | | COMMA | PERIOD | QUESTION | OVERALL | +-----------|-----------|-----------|-----------|--------- | +Precision | 0.425665 | 0.335190 | 0.698113 | 0.486323 | +Recall | 0.511278 | 0.572108 | 0.787234 | 0.623540 | +F1 | 0.464560 | 0.422717 | 0.740000 | 0.542426 | + +- BertBLSTM + - Config: conf/train_conf/bertBLSTM_zh.yaml + - Test results of avg_1 + + | | COMMA | PERIOD | QUESTION | OVERALL | +-----------|-----------|-----------|-----------|--------- | +Precision | 0.469484 | 0.550604 | 0.801887 | 0.607325 | +Recall | 0.580271 | 0.592408 | 0.817308 | 0.663329 | +F1 | 0.519031 | 0.570741 | 0.809524 | 0.633099 | + + - BertBLSTM/avg_1 tested on Biaobei (标贝) synthesized data + + | | COMMA | PERIOD | QUESTION | OVERALL | +-----------|-----------|-----------|-----------|--------- | +Precision | 0.217192 | 0.196339 | 0.820717 | 0.411416 | +Recall | 0.205922 | 0.892531 | 0.416162 | 0.504872 | +F1 | 0.211407 | 0.321873 | 0.552279 | 0.361853 | diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml new file mode 100644 index 00000000..9b1a2e01 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml @@ -0,0 +1,34 @@ +data: + language: chinese + raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/PFDSJ #path to raw dataset + raw_train_file: train + raw_dev_file: dev + raw_test_file: test + vocab_file: vocab + punc_file: punc_vocab + save_path: data/PFDSJ #path to save dataset + seq_len: 100 + batch_size: 10 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +model_type: blstm +model_params: + vocab_size: 3751 + embedding_size: 200 + hidden_size: 100 + num_layers: 3 + num_class: 5 + init_scale: 0.1 + +training: + n_epoch: 32 + lr: !!float 1e-4 + lr_decay: 1.0 + weight_decay: !!float 1e-06 + global_grad_clip: 5.0 + log_interval: 10 + + + diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml new file mode 100644 index 00000000..191bfd3e --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml @@ -0,0 +1,7 @@ +type: chinese +raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/iwslt2012_zh #path to raw dataset +raw_train_file: iwslt2012_train_zh
+raw_dev_file: iwslt2010_dev_zh +raw_test_file: biaobei_asr +punc_file: punc_vocab +save_path: data/iwslt2012_zh #path to save dataset \ No newline at end of file diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml new file mode 100644 index 00000000..d1f58aac --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml @@ -0,0 +1,49 @@ +data: + dataset_type: Bert + train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train + dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev + test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012_revise + data_params: + pretrained_token: bert-base-chinese + punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab + seq_len: 100 + batch_size: 64 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +checkpoint: + kbest_n: 5 + latest_n: 10 + metric_type: F1 + + +model_type: BertBLSTM +model_params: + pretrained_token: bert-base-chinese + output_size: 4 + dropout: 0.0 + bert_size: 768 + blstm_size: 128 + num_blstm_layers: 2 + init_scale: 0.1 + +# model_type: BertChLinear +# model_params: +# pretrained_token: bert-base-chinese +# output_size: 4 +# dropout: 0.0 +# bert_size: 768 + +training: + n_epoch: 100 + lr: !!float 1e-5 + lr_decay: 1.0 + weight_decay: !!float 1e-06 + global_grad_clip: 5.0 + log_interval: 10 + log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/bertBLSTM_zh0812.log + +testing: + log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_bertBLSTM_zh0812.log + diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml new file mode 100644 index 00000000..c422e840 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml @@ -0,0 +1,42 @@ +data: + dataset_type: Bert + train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train + dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev + test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012 + data_params: + pretrained_token: bert-base-chinese + punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab + seq_len: 100 + batch_size: 32 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +checkpoint: + kbest_n: 10 + latest_n: 10 + metric_type: F1 + + +model_type: BertLinear +model_params: + pretrained_token: bert-base-chinese + output_size: 4 + dropout: 0.2 + bert_size: 768 + hiddensize: 1568 + + +training: + n_epoch: 50 + lr: !!float 1e-5 + lr_decay: 1.0 + weight_decay: !!float 1e-06 + global_grad_clip: 5.0 + log_interval: 10 + log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/train_linear0812.log + +testing: + log_interval: 10 + log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_linear0812.log + diff --git
a/text_processing/examples/punctuation_restoration/chinese/local/avg.sh b/text_processing/examples/punctuation_restoration/chinese/local/avg.sh new file mode 100644 index 00000000..b8c14c66 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/local/avg.sh @@ -0,0 +1,23 @@ +#! /usr/bin/env bash + +if [ $# != 2 ]; then + echo "usage: ${0} ckpt_dir avg_num" + exit -1 +fi + +ckpt_dir=${1} +average_num=${2} +decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams + +python3 -u ${BIN_DIR}/avg_model.py \ +--dst_model ${decode_checkpoint} \ +--ckpt_dir ${ckpt_dir} \ +--num ${average_num} \ +--val_best + +if [ $? -ne 0 ]; then + echo "Failed in avg ckpt!" + exit 1 +fi + +exit 0 \ No newline at end of file diff --git a/text_processing/examples/punctuation_restoration/chinese/local/data.sh b/text_processing/examples/punctuation_restoration/chinese/local/data.sh new file mode 100644 index 00000000..aff7203c --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/local/data.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +if [ $# != 1 ];then + echo "usage: ${0} data_pre_conf" + echo $1 + exit -1 +fi + +data_pre_conf=$1 + +python3 -u ${BIN_DIR}/pre_data.py \ +--config ${data_pre_conf} + +if [ $? -ne 0 ]; then + echo "Failed in data preparation!" + exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/chinese/local/test.sh b/text_processing/examples/punctuation_restoration/chinese/local/test.sh new file mode 100644 index 00000000..6db75ca2 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/local/test.sh @@ -0,0 +1,32 @@ + +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + + +python3 -u ${BIN_DIR}/test.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/chinese/local/train.sh b/text_processing/examples/punctuation_restoration/chinese/local/train.sh new file mode 100644 index 00000000..f6bd2c98 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/local/train.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_name=$2 + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi + +mkdir -p exp + +python3 -u ${BIN_DIR}/train.py \ +--device ${device} \ +--nproc ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} + +if [ $? -ne 0 ]; then + echo "Failed in training!"
+ exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/chinese/path.sh b/text_processing/examples/punctuation_restoration/chinese/path.sh new file mode 100644 index 00000000..8154cc78 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/path.sh @@ -0,0 +1,13 @@ +export MAIN_ROOT=${PWD}/../../../ + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + + +export BIN_DIR=${MAIN_ROOT}/speechtask/punctuation_restoration/bin diff --git a/text_processing/examples/punctuation_restoration/chinese/run.sh b/text_processing/examples/punctuation_restoration/chinese/run.sh new file mode 100644 index 00000000..bb3d25d4 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/run.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -e +source path.sh + + +## stage, gpu, train_config, avg_num, data_config +if [ $# -lt 4 ]; then + echo "usage: bash ./run.sh stage gpu train_config avg_num data_config" + echo "eg: bash ./run.sh 0 0 train_config 1 data_config " + exit -1 +fi + +stage=$1 +stop_stage=100 +gpus=$2 +conf_path=$3 +avg_num=$4 +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ $stage -le 0 ]; then + if [ $# -eq 5 ]; then + data_pre_conf=$5 + # prepare data + bash ./local/data.sh ${data_pre_conf} || exit -1 + else + echo "data_pre_conf does not exist!" + exit -1 + fi +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi diff --git a/text_processing/examples/punctuation_restoration/english/README.md b/text_processing/examples/punctuation_restoration/english/README.md new file mode 100644 index 00000000..7955bb7d --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/README.md @@ -0,0 +1,23 @@ +# English Experiment Example +## Test data: +- IWSLT2012 English: test2011 + +## Running the code +- Run `run.sh 0 0 conf/train_conf/bertBLSTM_base_en.yaml 1 conf/data_conf/english.yaml ` + + +## Results from the related paper: +> * Nagy, Attila, Bence Bial, and Judit Ács. "Automatic punctuation restoration with BERT models."
arXiv preprint arXiv:2101.07343 (2021)* +> + + +## Results: +- BertLinear + - Config: conf/train_conf/bertLinear_en.yaml + - Test results: exp/bertLinear_enRe/checkpoints/3.pdparams + + | | COMMA | PERIOD | QUESTION | OVERALL | +-----------|-----------|-----------|-----------|--------- | +Precision |0.667910 |0.715778 |0.822222 |0.735304 | +Recall |0.755274 |0.868188 |0.804348 |0.809270 | +F1 |0.708911 |0.784651 |0.813187 |0.768916 | diff --git a/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml b/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml new file mode 100644 index 00000000..44834f28 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml @@ -0,0 +1,7 @@ +type: english +raw_path: /data4/mahaoxin/PaddleSpeechTask/data/english/iwslt2012_en #path to raw dataset +raw_train_file: iwslt2012_train_en +raw_dev_file: iwslt2010_dev_en +raw_test_file: iwslt2011_test_en +punc_file: punc_vocab +save_path: data/iwslt2012_en #path to save dataset \ No newline at end of file diff --git a/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml new file mode 100644 index 00000000..7f4383d4 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml @@ -0,0 +1,47 @@ +data: + dataset_type: Bert + train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train + dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev + test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011 + data_params: + pretrained_token: bert-base-uncased #english + punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab + seq_len: 50 + batch_size: 32 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +checkpoint: + kbest_n: 10 + latest_n: 10 + +model_type: BertBLSTM +model_params: + pretrained_token: bert-base-uncased + output_size: 4 + dropout: 0.0 + bert_size: 768 + blstm_size: 128 + num_blstm_layers: 2 + init_scale: 0.2 +# model_type: BertChLinear +# model_params: +# pretrained_token: bert-large-uncased +# output_size: 4 +# dropout: 0.0 +# bert_size: 768 + +training: + n_epoch: 100 + lr: !!float 1e-5 + lr_decay: 1.0 + weight_decay: !!float 1e-06 + global_grad_clip: 5.0 + log_interval: 10 + log_path: log/bertBLSTM_base0812.log + +testing: + log_path: log/testbertBLSTM_base0812.log + + diff --git a/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml new file mode 100644 index 00000000..8cac9889 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml @@ -0,0 +1,39 @@ +data: + dataset_type: Bert + train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train + dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev + test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011 + data_params: + pretrained_token: bert-base-uncased #english + punc_path:
/data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab + seq_len: 100 + batch_size: 32 + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + +checkpoint: + kbest_n: 10 + latest_n: 10 + +model_type: BertLinear +model_params: + pretrained_token: bert-base-uncased + output_size: 4 + dropout: 0.2 + bert_size: 768 + hiddensize: 1568 + +training: + n_epoch: 20 + lr: !!float 1e-5 + lr_decay: 1.0 + weight_decay: !!float 1e-06 + global_grad_clip: 3.0 + log_interval: 10 + log_path: log/train_linear0820.log + +testing: + log_path: log/test2011_linear0820.log + + diff --git a/text_processing/examples/punctuation_restoration/english/local/avg.sh b/text_processing/examples/punctuation_restoration/english/local/avg.sh new file mode 100644 index 00000000..b8c14c66 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/local/avg.sh @@ -0,0 +1,23 @@ +#! /usr/bin/env bash + +if [ $# != 2 ]; then + echo "usage: ${0} ckpt_dir avg_num" + exit -1 +fi + +ckpt_dir=${1} +average_num=${2} +decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams + +python3 -u ${BIN_DIR}/avg_model.py \ +--dst_model ${decode_checkpoint} \ +--ckpt_dir ${ckpt_dir} \ +--num ${average_num} \ +--val_best + +if [ $? -ne 0 ]; then + echo "Failed in avg ckpt!" + exit 1 +fi + +exit 0 \ No newline at end of file diff --git a/text_processing/examples/punctuation_restoration/english/local/data.sh b/text_processing/examples/punctuation_restoration/english/local/data.sh new file mode 100644 index 00000000..1b0c62b1 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/local/data.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +if [ $# != 1 ];then + echo "usage: ${0} config_path" + exit -1 +fi + +config_path=$1 + +python3 -u ${BIN_DIR}/pre_data.py \ +--config ${config_path} + +if [ $? -ne 0 ]; then + echo "Failed in data preparation!" + exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/english/local/test.sh b/text_processing/examples/punctuation_restoration/english/local/test.sh new file mode 100644 index 00000000..6db75ca2 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/local/test.sh @@ -0,0 +1,32 @@ + +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + + +python3 -u ${BIN_DIR}/test.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/english/local/train.sh b/text_processing/examples/punctuation_restoration/english/local/train.sh new file mode 100644 index 00000000..f6bd2c98 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/local/train.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_name=$2 + +device=gpu +if [ ${ngpu} == 0 ];then + device=cpu +fi + +mkdir -p exp + +python3 -u ${BIN_DIR}/train.py \ +--device ${device} \ +--nproc ${ngpu} \ +--config ${config_path} \ +--output exp/${ckpt_name} + +if [ $?
-ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 diff --git a/text_processing/examples/punctuation_restoration/english/path.sh b/text_processing/examples/punctuation_restoration/english/path.sh new file mode 100644 index 00000000..8154cc78 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/path.sh @@ -0,0 +1,13 @@ +export MAIN_ROOT=${PWD}/../../../ + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + + +export BIN_DIR=${MAIN_ROOT}/speechtask/punctuation_restoration/bin diff --git a/text_processing/examples/punctuation_restoration/english/run.sh b/text_processing/examples/punctuation_restoration/english/run.sh new file mode 100644 index 00000000..bb3d25d4 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/english/run.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -e +source path.sh + + +## stage, gpu, train_config, avg_num, data_config +if [ $# -lt 4 ]; then + echo "usage: bash ./run.sh stage gpu train_config avg_num data_config" + echo "eg: bash ./run.sh 0 0 train_config 1 data_config " + exit -1 +fi + +stage=$1 +stop_stage=100 +gpus=$2 +conf_path=$3 +avg_num=$4 +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ $stage -le 0 ]; then + if [ $# -eq 5 ]; then + data_pre_conf=$5 + # prepare data + bash ./local/data.sh ${data_pre_conf} || exit -1 + else + echo "data_pre_conf does not exist!" + exit -1 + fi +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi diff --git a/text_processing/requirements.txt b/text_processing/requirements.txt new file mode 100644 index 00000000..685ab029 --- /dev/null +++ b/text_processing/requirements.txt @@ -0,0 +1,6 @@ +numpy +pyyaml +tensorboardX +tqdm +ujson +yacs diff --git a/text_processing/speechtask/punctuation_restoration/bin/avg_model.py b/text_processing/speechtask/punctuation_restoration/bin/avg_model.py new file mode 100644 index 00000000..a012e258 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/avg_model.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import argparse +import glob +import json +import os + +import numpy as np +import paddle + + +def main(args): + paddle.set_device('cpu') + + val_scores = [] + best_val_scores = np.array([]) + selected_epochs = np.array([], dtype=np.int64) + if args.val_best: + jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') + for y in jsons: + with open(y, 'r') as f: + dic_json = json.load(f) + score = dic_json['F1'] + epoch = dic_json['epoch'] + if epoch >= args.min_epoch and epoch <= args.max_epoch: + val_scores.append((epoch, score)) + + val_scores = np.array(val_scores) + sort_idx = np.argsort(val_scores[:, 1])[::-1] # F1: higher is better, so sort descending + sorted_val_scores = val_scores[sort_idx] + path_list = [ + args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) + for epoch in sorted_val_scores[:args.num, 0] + ] + + best_val_scores = sorted_val_scores[:args.num, 1] + selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) + print("best val scores = " + str(best_val_scores)) + print("selected epochs = " + str(selected_epochs)) + else: + path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams') + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-args.num:] + + print(path_list) + + avg = None + num = args.num + assert num == len(path_list) + for path in path_list: + print(f'Processing {path}') + states = paddle.load(path) + if avg is None: + avg = states + else: + for k in avg.keys(): + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + avg[k] /= num + + paddle.save(avg, args.dst_model) + print(f'Saving to {args.dst_model}') + + meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json' + with open(meta_path, 'w') as f: + data = json.dumps({ + "avg_ckpt": args.dst_model, + "ckpt": path_list, + "epoch": selected_epochs.tolist(), + "val_loss": best_val_scores.tolist(), + }) + f.write(data + "\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='path to save the averaged model') + parser.add_argument( + '--ckpt_dir', required=True, help='ckpt model dir for average') + parser.add_argument( + '--val_best', action="store_true", help='select checkpoints with the best validation F1') + parser.add_argument( + '--num', default=5, type=int, help='number of checkpoints to average') + parser.add_argument( + '--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') + + args = parser.parse_args() + print(args) + + main(args) diff --git a/text_processing/speechtask/punctuation_restoration/bin/pre_data.py b/text_processing/speechtask/punctuation_restoration/bin/pre_data.py new file mode 100644 index 00000000..a074d7e3 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/pre_data.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
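
For reference, the averaged checkpoint written by avg_model.py can be restored with `paddle.load` plus `set_state_dict`. The sketch below is illustrative only, not part of this patch; it assumes the hyperparameters from conf/train_conf/bertBLSTM_zh.yaml and the checkpoint layout produced by run.sh:

```python
# Illustrative only: load the averaged parameters produced by avg_model.py.
import paddle
from speechtask.punctuation_restoration.model.BertBLSTM import BertBLSTMPunc

# Hyperparameters must match the training config (here: bertBLSTM_zh.yaml).
model = BertBLSTMPunc(
    pretrained_token="bert-base-chinese",
    output_size=4,
    dropout=0.0,
    bert_size=768,
    blstm_size=128,
    num_blstm_layers=2,
    init_scale=0.1)
state = paddle.load("exp/bertBLSTM_zh/checkpoints/avg_1.pdparams")  # illustrative path
model.set_state_dict(state)
model.eval()
```
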
+"""Data preparation for punctuation_restoration task.""" +import yaml +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.punct_pre import process_chinese_pure_senetence +from speechtask.punctuation_restoration.utils.punct_pre import process_english_pure_senetence +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +# create dataset from raw data files +def main(config, args): + print("Start preparing data from raw data.") + if (config['type'] == 'chinese'): + process_chinese_pure_senetence(config) + elif (config['type'] == 'english'): + print('english!!!!') + process_english_pure_senetence(config) + else: + print('Error: Type should be chinese or english!!!!') + raise ValueError('Type should be chinese or english') + + print("Finish preparing data.") + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # config.freeze() + print(config) + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/bin/test.py b/text_processing/speechtask/punctuation_restoration/bin/test.py new file mode 100644 index 00000000..17892fdb --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/test.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for model.""" +import yaml +from speechtask.punctuation_restoration.training.trainer import Tester +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/bin/train.py b/text_processing/speechtask/punctuation_restoration/bin/train.py new file mode 100644 index 00000000..1ffd79b7 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/train.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer for punctuation_restoration task.""" +import yaml +from paddle import distributed as dist +from speechtask.punctuation_restoration.training.trainer import Trainer +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.device == "gpu" and args.nprocs > 1: + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/io/__init__.py b/text_processing/speechtask/punctuation_restoration/io/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/text_processing/speechtask/punctuation_restoration/io/collator.py b/text_processing/speechtask/punctuation_restoration/io/collator.py new file mode 100644 index 00000000..5b63b584 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/collator.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
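
For orientation, both entry points above parse the YAML config into a plain nested dict with `yaml.FullLoader`, so sections such as `data`, `model_params`, and `training` are ordinary dictionaries. A small sketch (the config path is illustrative):

```python
# Illustrative: inspect a training config the same way train.py/test.py load it.
import yaml

with open("conf/train_conf/bertBLSTM_zh.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

print(config["model_type"])                      # BertBLSTM
print(config["data"]["data_params"]["seq_len"])  # window length used by the dataset
print(config["training"]["lr"])                  # the !!float tag yields 1e-05
```

On multi-process GPU runs, train.py dispatches the same `main_sp` through `paddle.distributed.spawn`; with a single process it simply calls it directly.
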
+import numpy as np + +__all__ = ["TextCollator"] + + +class TextCollator(): + def __init__(self, padding_value): + self.padding_value = padding_value + + def __call__(self, batch): + """Batch examples. + Args: + batch ([List]): batch is (text, punctuation) + text (List[int] ) shape (batch, L) + punctuation (List[int] or str): shape (batch, L) + Returns: + tuple(text, punctuation): batched data. + text : (B, Lmax) + punctuation : (B, Lmax) + """ + texts = [] + punctuations = [] + for text, punctuation in batch: + + texts.append(text) + punctuations.append(punctuation) + + # [B, T] + x_pad = self.pad_sequence(texts).astype(np.int64) + # print(x_pad.shape) + # pad_list(audios, 0.0).astype(np.float32) + # ilens = np.array(audio_lens).astype(np.int64) + y_pad = self.pad_sequence(punctuations).astype(np.int64) + # print(y_pad.shape) + # olens = np.array(text_lens).astype(np.int64) + return x_pad, y_pad + + def pad_sequence(self, sequences): + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_len = max([len(s) for s in sequences]) + out_dims = (len(sequences), max_len) + + out_tensor = np.full(out_dims, + self.padding_value) #, dtype=sequences[0].dtype) + for i, tensor in enumerate(sequences): + length = len(tensor) + # use index notation to prevent duplicate references to the tensor + out_tensor[i, :length] = tensor + + return out_tensor diff --git a/text_processing/speechtask/punctuation_restoration/io/common.py b/text_processing/speechtask/punctuation_restoration/io/common.py new file mode 100644 index 00000000..3ed4a604 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/common.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import re +import unicodedata + +import ujson + +PAD = "<PAD>" +UNK = "<UNK>" +NUM = "<NUM>" +END = "<END>" +SPACE = "_SPACE" + + +def write_json(filename, dataset): + with codecs.open(filename, mode="w", encoding="utf-8") as f: + ujson.dump(dataset, f) + + +def word_convert(word, keep_number=True, lowercase=True): + if not keep_number: + if is_digit(word): + word = NUM + if lowercase: + word = word.lower() + return word + + +def is_digit(word): + try: + float(word) + return True + except ValueError: + pass + try: + unicodedata.numeric(word) + return True + except (TypeError, ValueError): + pass + result = re.compile(r'^[-+]?[0-9]+,[0-9]+$').match(word) + if result: + return True + return False diff --git a/text_processing/speechtask/punctuation_restoration/io/dataset.py b/text_processing/speechtask/punctuation_restoration/io/dataset.py new file mode 100644 index 00000000..17c13c38 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/dataset.py @@ -0,0 +1,310 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random + +import numpy as np +import paddle +from paddle.io import Dataset +from paddlenlp.transformers import BertTokenizer +# from speechtask.punctuation_restoration.utils.punct_prepro import load_dataset + +__all__ = ["PuncDataset", "PuncDatasetFromBertTokenizer"] + + +class PuncDataset(Dataset): + """Representing a Dataset + superclass + ---------- + data.Dataset : + Dataset is an abstract class, representing the real data. + """ + + def __init__(self, train_path, vocab_path, punc_path, seq_len=100): + # check that the input files exist + print(train_path) + print(vocab_path) + assert os.path.exists(train_path), "train file does not exist" + assert os.path.exists(vocab_path), "vocab file does not exist" + assert os.path.exists(punc_path), "punctuation file does not exist" + self.seq_len = seq_len + + self.word2id = self.load_vocab( + vocab_path, extra_word_list=['<UNK>', '<END>']) + self.id2word = {v: k for k, v in self.word2id.items()} + self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "]) + self.id2punc = {k: v for (v, k) in self.punc2id.items()} + + tmp_seqs = open(train_path, encoding='utf-8').readlines() + self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()] + # print(self.txt_seqs[:10]) + # with open('./txt_seq', 'w', encoding='utf-8') as w: + # print(self.txt_seqs, file=w) + self.preprocess(self.txt_seqs) + print('---punc-') + print(self.punc2id) + + def __len__(self): + """return the sentence nums in .txt + """ + return self.in_len + + def __getitem__(self, index): + """Return the tensor pair at the given index (sequence of input text ids, sequence of the corresponding punctuation ids). + Parameters + ---------- + index : int index + """ + return self.input_data[index], self.label[index] + + def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'): + n = len(extra_word_list) + with open(vocab_path, encoding='utf-8') as vf: + vocab = {word.strip(): i + n for i, word in enumerate(vf)} + for i, word in enumerate(extra_word_list): + vocab[word] = i + return vocab + + def preprocess(self, txt_seqs: list): + """Convert the text into pairs of word ids and the punctuation ids to predict. + Parameters + ---------- + txt : text + each word in the text is followed by a space, and each punctuation mark is also followed by a space + """ + input_data = [] + label = [] + input_r = [] + label_r = [] + # txt_seqs is a list like: ['char', 'char', 'char', '*,*', 'char', ......]
+ count = 0 + length = len(txt_seqs) + for token in txt_seqs: + count += 1 + if count == length: + break + if token in self.punc2id: + continue + punc = txt_seqs[count] + if punc not in self.punc2id: + # print('punctuation {}:'.format(count), self.punc2id[" "]) + label.append(self.punc2id[" "]) + input_data.append( + self.word2id.get(token, self.word2id['<UNK>'])) + input_r.append(token) + label_r.append(' ') + else: + # print('punctuation {}:'.format(count), self.punc2id[punc]) + label.append(self.punc2id[punc]) + input_data.append( + self.word2id.get(token, self.word2id['<UNK>'])) + input_r.append(token) + label_r.append(punc) + if len(input_data) != len(label): + raise AssertionError('error: length input_data != label') + # the code below chunks the stream into segments of length seq_len + print(len(input_data)) + self.in_len = len(input_data) // self.seq_len + len_tmp = self.in_len * self.seq_len + input_data = input_data[:len_tmp] + label = label[:len_tmp] + + self.input_data = paddle.to_tensor( + np.array(input_data, dtype='int64').reshape(-1, self.seq_len)) + self.label = paddle.to_tensor( + np.array(label, dtype='int64').reshape(-1, self.seq_len)) + + +# unk_token='[UNK]' +# sep_token='[SEP]' +# pad_token='[PAD]' +# cls_token='[CLS]' +# mask_token='[MASK]' + + +class PuncDatasetFromBertTokenizer(Dataset): + """Representing a Dataset + superclass + ---------- + data.Dataset : + Dataset is an abstract class, representing the real data. + """ + + def __init__(self, + train_path, + is_eval, + pretrained_token, + punc_path, + seq_len=100): + # check that the input files exist + print(train_path) + self.tokenizer = BertTokenizer.from_pretrained( + pretrained_token, do_lower_case=True) + self.paddingID = self.tokenizer.pad_token_id + assert os.path.exists(train_path), "train file does not exist" + assert os.path.exists(punc_path), "punctuation file does not exist" + self.seq_len = seq_len + + self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "]) + self.id2punc = {k: v for (v, k) in self.punc2id.items()} + + tmp_seqs = open(train_path, encoding='utf-8').readlines() + self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()] + # print(self.txt_seqs[:10]) + # with open('./txt_seq', 'w', encoding='utf-8') as w: + # print(self.txt_seqs, file=w) + if (is_eval): + self.preprocess(self.txt_seqs) + else: + self.preprocess_shift(self.txt_seqs) + print("data len: %d" % (len(self.input_data))) + print('---punc-') + print(self.punc2id) + + def __len__(self): + """return the sentence nums in .txt + """ + return self.in_len + + def __getitem__(self, index): + """Return the tensor pair at the given index (sequence of input text ids, sequence of the corresponding punctuation ids). + Parameters + ---------- + index : int index + """ + return self.input_data[index], self.label[index] + + def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'): + n = len(extra_word_list) + with open(vocab_path, encoding='utf-8') as vf: + vocab = {word.strip(): i + n for i, word in enumerate(vf)} + for i, word in enumerate(extra_word_list): + vocab[word] = i + return vocab + + def preprocess(self, txt_seqs: list): + """Convert the text into pairs of word ids and the punctuation ids to predict. + Parameters + ---------- + txt : text + each word in the text is followed by a space, and each punctuation mark is also followed by a space + """ + input_data = [] + label = [] + # txt_seqs is a list like: ['char', 'char', 'char', '*,*', 'char', ......]
+ count = 0 + for i in range(len(txt_seqs) - 1): + word = txt_seqs[i] + punc = txt_seqs[i + 1] + if word in self.punc2id: + continue + + token = self.tokenizer(word) + x = token["input_ids"][1:-1] + input_data.extend(x) + + for i in range(len(x) - 1): + label.append(self.punc2id[" "]) + + if punc not in self.punc2id: + # print('punctuation {}:'.format(count), self.punc2id[" "]) + label.append(self.punc2id[" "]) + else: + label.append(self.punc2id[punc]) + + if len(input_data) != len(label): + raise AssertionError('error: length input_data != label') + # the code below chunks the stream into segments of length seq_len + + # print(len(input_data[0])) + # print(len(label)) + self.in_len = len(input_data) // self.seq_len + len_tmp = self.in_len * self.seq_len + input_data = input_data[:len_tmp] + label = label[:len_tmp] + # # print(input_data) + # print(type(input_data)) + # tmp=np.array(input_data) + # print('--~~~~~~~~~~~~~') + # print(type(tmp)) + # print(tmp.shape) + self.input_data = paddle.to_tensor( + np.array(input_data, dtype='int64').reshape( + -1, self.seq_len)) #, dtype='int64' + self.label = paddle.to_tensor( + np.array(label, dtype='int64').reshape( + -1, self.seq_len)) #, dtype='int64' + + def preprocess_shift(self, txt_seqs: list): + """Convert the text into pairs of word ids and the punctuation ids to predict, using randomly shifted windows. + Parameters + ---------- + txt : text + each word in the text is followed by a space, and each punctuation mark is also followed by a space + """ + input_data = [] + label = [] + # txt_seqs is a list like: ['char', 'char', 'char', '*,*', 'char', ......] + count = 0 + for i in range(len(txt_seqs) - 1): + word = txt_seqs[i] + punc = txt_seqs[i + 1] + if word in self.punc2id: + continue + + token = self.tokenizer(word) + x = token["input_ids"][1:-1] + input_data.extend(x) + + for i in range(len(x) - 1): + label.append(self.punc2id[" "]) + + if punc not in self.punc2id: + # print('punctuation {}:'.format(count), self.punc2id[" "]) + label.append(self.punc2id[" "]) + else: + label.append(self.punc2id[punc]) + + if len(input_data) != len(label): + raise AssertionError('error: length input_data != label') + + # print(len(input_data[0])) + # print(len(label)) + start = 0 + processed_data = [] + processed_label = [] + while (start < len(input_data) - self.seq_len): + # end=start+self.seq_len + end = random.randint(start + self.seq_len // 2, + start + self.seq_len) + processed_data.append(input_data[start:end]) + processed_label.append(label[start:end]) + + start = start + random.randint(1, self.seq_len // 2) + + self.in_len = len(processed_data) + # # print(input_data) + # print(type(input_data)) + # tmp=np.array(input_data) + # print('--~~~~~~~~~~~~~') + # print(type(tmp)) + # print(tmp.shape) + self.input_data = processed_data + #paddle.to_tensor(np.array(processed_data, dtype='int64')) #, dtype='int64' + self.label = processed_label + #paddle.to_tensor(np.array(processed_label, dtype='int64')) #, dtype='int64' + + +if __name__ == '__main__': + pass # constructing a PuncDataset here requires train/vocab/punc paths; see the example configs diff --git a/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py b/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py new file mode 100644 index 00000000..bc953adf --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn +import paddle.nn.initializer as I +from paddlenlp.transformers import BertForTokenClassification + + +class BertBLSTMPunc(nn.Layer): + def __init__(self, + pretrained_token="bert-large-uncased", + output_size=4, + dropout=0.0, + bert_size=768, + blstm_size=128, + num_blstm_layers=2, + init_scale=0.1): + super(BertBLSTMPunc, self).__init__() + self.output_size = output_size + self.bert = BertForTokenClassification.from_pretrained( + pretrained_token, num_classes=bert_size) + # self.bert_vocab_size = vocab_size + # self.bn = nn.BatchNorm1d(segment_size*self.bert_vocab_size) + # self.fc = nn.Linear(segment_size*self.bert_vocab_size, output_size) + + self.lstm = nn.LSTM( + input_size=bert_size, + hidden_size=blstm_size, + num_layers=num_blstm_layers, + direction="bidirect", + weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + # NOTE: the fc input is blstm_size * 2 because the BLSTM is bidirectional; the BERT branch outputs hidden states of size self.bert_size + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(blstm_size * 2, output_size) + self.softmax = nn.Softmax() + + def forward(self, x): + # print('input :', x.shape) + x = self.bert(x) # [B, T, bert_size] token-level features + # print('after bert :', x.shape) + + y, (_, _) = self.lstm(x) + # print('after lstm :', y.shape) + y = self.fc(self.dropout(y)) + y = paddle.reshape(y, shape=[-1, self.output_size]) + # print('after fc :', y.shape) + + logit = self.softmax(y) + # print('after softmax :', logit.shape) + + return y, logit + + +if __name__ == '__main__': + print('start model') + model = BertBLSTMPunc() + x = paddle.randint(low=0, high=40, shape=[2, 5]) + print(x) + y, logit = model(x) diff --git a/text_processing/speechtask/punctuation_restoration/model/BertLinear.py b/text_processing/speechtask/punctuation_restoration/model/BertLinear.py new file mode 100644 index 00000000..854f522c --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/model/BertLinear.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
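
To make the tensor flow concrete, here is a hypothetical end-to-end inference sketch for `BertBLSTMPunc`; it is not part of this patch. The id-to-punctuation mapping and the loaded `model` are assumptions, and the mapping depends on the punc_vocab used at training time:

```python
# Hypothetical inference sketch: restore punctuation for one Chinese sentence.
import paddle
from paddlenlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
id2punc = {0: " ", 1: ",", 2: "。", 3: "?"}  # assumed punc_vocab order

text = "今天天气很好我们去爬山吧"
token_ids = tokenizer(text)["input_ids"][1:-1]  # strip [CLS]/[SEP]
x = paddle.to_tensor([token_ids], dtype="int64")

_, logit = model(x)  # `model`: a loaded BertBLSTMPunc, as in the earlier sketch
preds = paddle.argmax(logit, axis=-1).numpy()  # one punctuation label per token

tokens = tokenizer.convert_ids_to_tokens(token_ids)
print("".join(t + id2punc[int(p)].strip() for t, p in zip(tokens, preds)))
```
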
+import paddle +import paddle.nn as nn +from paddlenlp.transformers import BertForTokenClassification + + +class BertLinearPunc(nn.Layer): + def __init__(self, + pretrained_token="bert-base-uncased", + output_size=4, + dropout=0.2, + bert_size=768, + hiddensize=1568): + super(BertLinearPunc, self).__init__() + self.output_size = output_size + self.bert = BertForTokenClassification.from_pretrained( + pretrained_token, num_classes=bert_size) + # self.bert_vocab_size = vocab_size + # self.bn = nn.BatchNorm1d(segment_size*self.bert_vocab_size) + # self.fc = nn.Linear(segment_size*self.bert_vocab_size, output_size) + + # NOTE: two dense layers are stacked on top of the BERT hidden states of size self.bert_size + self.dropout1 = nn.Dropout(dropout) + self.fc1 = nn.Linear(bert_size, hiddensize) + self.dropout2 = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hiddensize, output_size) + self.softmax = nn.Softmax() + + def forward(self, x): + # print('input :', x.shape) + x = self.bert(x) # [B, T, bert_size] token-level features + # print('after bert :', x.shape) + + x = self.fc1(self.dropout1(x)) + x = self.fc2(self.relu(self.dropout2(x))) + x = paddle.reshape(x, shape=[-1, self.output_size]) + # print('after fc :', x.shape) + + logit = self.softmax(x) + # print('after softmax :', logit.shape) + + return x, logit + + +if __name__ == '__main__': + print('start model') + model = BertLinearPunc() + x = paddle.randint(low=0, high=40, shape=[2, 5]) + print(x) + y, logit = model(x) diff --git a/text_processing/speechtask/punctuation_restoration/model/blstm.py b/text_processing/speechtask/punctuation_restoration/model/blstm.py new file mode 100644 index 00000000..fcfd31a3 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/model/blstm.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn +import paddle.nn.initializer as I + + +class BiLSTM(nn.Layer): + """LSTM for Punctuation Restoration + """ + + def __init__(self, + vocab_size, + embedding_size, + hidden_size, + num_layers, + num_class, + init_scale=0.1): + super(BiLSTM, self).__init__() + # hyper parameters + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_class = num_class + + # network layers + self.embedding = nn.Embedding( + vocab_size, + embedding_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + # print(hidden_size) + # print(embedding_size) + self.lstm = nn.LSTM( + input_size=embedding_size, + hidden_size=hidden_size, + num_layers=num_layers, + direction="bidirect", + weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + # The LSTM above is bidirectional, so the fc layer below takes hidden_size * 2 input features.
+        self.fc = nn.Linear(
+            in_features=hidden_size * 2,
+            out_features=num_class,
+            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)),
+            bias_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)))
+
+        self.softmax = nn.Softmax()
+
+    def forward(self, input):
+        """The forward pass of the network.
+
+        Parameters
+        ----------
+        input : Tensor
+            Token ids, batch first: [batch, seq_len]
+        """
+        embedding = self.embedding(input)
+        # All sequences in a batch share the same length here, so packing of
+        # padded sequences is not needed before the LSTM.
+        y, (_, _) = self.lstm(embedding)
+        y = self.fc(y)
+        y = paddle.reshape(y, shape=[-1, self.num_class])
+        logit = self.softmax(y)
+        return y, logit
diff --git a/text_processing/speechtask/punctuation_restoration/model/lstm.py b/text_processing/speechtask/punctuation_restoration/model/lstm.py
new file mode 100644
index 00000000..5ec68533
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/model/lstm.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
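+# Note: RnnLm below is a unidirectional LSTM tagger that predicts one
+# punctuation label per input token. A minimal usage sketch (the sizes are
+# illustrative assumptions, not values from the shipped configs):
+#
+#   model = RnnLm(vocab_size=3000, punc_size=4, hidden_size=100)
+#   token_ids = paddle.randint(low=0, high=3000, shape=[2, 10])
+#   y, prob = model(token_ids)   # y, prob: [batch * seq_len, punc_size]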
+import paddle +import paddle.nn as nn +import paddle.nn.initializer as I + + +class RnnLm(nn.Layer): + def __init__(self, + vocab_size, + punc_size, + hidden_size, + num_layers=1, + init_scale=0.1, + dropout=0.0): + super(RnnLm, self).__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.init_scale = init_scale + self.punc_size = punc_size + + self.embedder = nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.lstm = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.fc = nn.Linear( + hidden_size, + punc_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + bias_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.dropout = nn.Dropout(p=dropout) + self.softmax = nn.Softmax() + + def forward(self, inputs): + x = inputs + x_emb = self.embedder(x) + x_emb = self.dropout(x_emb) + + y, (_, _) = self.lstm(x_emb) + + y = self.dropout(y) + y = self.fc(y) + y = paddle.reshape(y, shape=[-1, self.punc_size]) + logit = self.softmax(y) + return y, logit + + +class CrossEntropyLossForLm(nn.Layer): + def __init__(self): + super(CrossEntropyLossForLm, self).__init__() + + def forward(self, y, label): + label = paddle.unsqueeze(label, axis=2) + loss = paddle.nn.functional.cross_entropy( + input=y, label=label, reduction='none') + loss = paddle.squeeze(loss, axis=[2]) + loss = paddle.mean(loss, axis=[0]) + loss = paddle.sum(loss) + return loss diff --git a/text_processing/speechtask/punctuation_restoration/modules/__init__.py b/text_processing/speechtask/punctuation_restoration/modules/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/text_processing/speechtask/punctuation_restoration/modules/activation.py b/text_processing/speechtask/punctuation_restoration/modules/activation.py new file mode 100644 index 00000000..6a13e4aa --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/activation.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
+
+
+def brelu(x, t_min=0.0, t_max=24.0, name=None):
+    # paddle.to_tensor is dygraph-only and does not work under JIT, so the
+    # bounds are built with paddle.full instead.
+    t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
+    t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
+    return x.maximum(t_min).minimum(t_max)
+
+
+class GLU(nn.Layer):
+    """Gated Linear Unit as a Layer, so it can be used inside nn.Sequential.
+
+    Splits the input in two halves along `axis` and gates one with the
+    sigmoid of the other.
+    """
+
+    def __init__(self, axis: int=-1):
+        super().__init__()
+        self.axis = axis
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.axis)
+
+
+class LinearGLUBlock(nn.Layer):
+    """A linear Gated Linear Units (GLU) block."""
+
+    def __init__(self, idim: int):
+        """GLU.
+        Args:
+            idim (int): input and output dimension
+        """
+        super().__init__()
+        self.fc = nn.Linear(idim, idim * 2)
+
+    def forward(self, xs):
+        return F.glu(self.fc(xs), axis=-1)
+
+
+class ConvGLUBlock(nn.Layer):
+    def __init__(self, kernel_size, in_ch, out_ch, bottleneck_dim=0,
+                 dropout=0.):
+        """A convolutional Gated Linear Units (GLU) block.
+
+        Args:
+            kernel_size (int): kernel size
+            in_ch (int): number of input channels
+            out_ch (int): number of output channels
+            bottleneck_dim (int): dimension of the bottleneck layers for
+                computational efficiency. Defaults to 0.
+            dropout (float): dropout probability. Defaults to 0.
+        """
+
+        super().__init__()
+
+        self.conv_residual = None
+        if in_ch != out_ch:
+            self.conv_residual = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch, out_channels=out_ch,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            self.dropout_residual = nn.Dropout(p=dropout)
+
+        # causal padding: pad (kernel_size - 1) steps at the top of the time
+        # axis only; [left, right, top, bottom] for NCHW input
+        self.pad_left = nn.Pad2D([0, 0, kernel_size - 1, 0], value=0.)
+
+        layers = OrderedDict()
+        if bottleneck_dim == 0:
+            layers['conv'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=out_ch * 2,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout'] = nn.Dropout(p=dropout)
+            # GLU along the channel axis halves out_ch * 2 back to out_ch
+            layers['glu'] = GLU(axis=1)
+
+        elif bottleneck_dim > 0:
+            layers['conv_in'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=bottleneck_dim,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_in'] = nn.Dropout(p=dropout)
+            # double the channels before each GLU so gating halves them back
+            layers['conv_bottleneck'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=bottleneck_dim * 2,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU(axis=1)
+            layers['conv_out'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=out_ch * 2,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_out'] = nn.Dropout(p=dropout)
+            layers['glu_out'] = GLU(axis=1)
+
+        self.layers = nn.Sequential(layers)
+
+    def forward(self, xs):
+        """Forward pass.
+        Args:
+            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
+        Returns:
+            out (FloatTensor): `[B, out_ch, T, feat_dim]`
+        """
+        residual = xs
+        if self.conv_residual is not None:
+            residual = self.dropout_residual(self.conv_residual(residual))
+        xs = self.pad_left(xs)  # `[B, in_ch, T+kernel_size-1, 1]`
+        xs = self.layers(xs)  # `[B, out_ch, T, 1]` after the channel GLU
+        xs = xs + residual
+        return xs
+
+
+def get_activation(act):
+    """Return an activation function/layer by name."""
+    # Lazy load to avoid unused instances
+    activation_funcs = {
+        "hardtanh": paddle.nn.Hardtanh,
+        "tanh": paddle.nn.Tanh,
+        "relu": paddle.nn.ReLU,
+        "selu": paddle.nn.SELU,
+        "swish": paddle.nn.Swish,
+        "gelu": paddle.nn.GELU,
+        # brelu is a plain function, not a Layer class; wrap it so the
+        # uniform `activation_funcs[act]()` call below returns a callable.
+        "brelu": lambda: brelu,
+    }
+
+    return activation_funcs[act]()
diff --git a/text_processing/speechtask/punctuation_restoration/modules/attention.py b/text_processing/speechtask/punctuation_restoration/modules/attention.py
new file mode 100644
index 00000000..1a7363c4
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/modules/attention.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+import math
+from typing import Optional
+from typing import Tuple
+
+import paddle
+from paddle import nn
+from paddle.nn import initializer as I
+
+__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
+
+# Relative Positional Encodings
+# https://www.jianshu.com/p/c0608efcc26f
+# https://zhuanlan.zhihu.com/p/344604604
+
+
+class MultiHeadedAttention(nn.Layer):
+    """Multi-Head Attention layer."""
+
+    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
+        """Construct a MultiHeadedAttention object.
+        Args:
+            n_head (int): The number of heads.
+            n_feat (int): The number of features.
+            dropout_rate (float): Dropout rate.
+        """
+        super().__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self,
+                    query: paddle.Tensor,
+                    key: paddle.Tensor,
+                    value: paddle.Tensor
+                    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query (paddle.Tensor): Query tensor (#batch, time1, size).
+            key (paddle.Tensor): Key tensor (#batch, time2, size).
+            value (paddle.Tensor): Value tensor (#batch, time2, size).
+        Returns:
+            paddle.Tensor: Transformed query tensor, size
+                (#batch, n_head, time1, d_k).
+            paddle.Tensor: Transformed key tensor, size
+                (#batch, n_head, time2, d_k).
+            paddle.Tensor: Transformed value tensor, size
+                (#batch, n_head, time2, d_k).
+ """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, + value: paddle.Tensor, + scores: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute attention context vector. + Args: + value (paddle.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (paddle.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (paddle.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + paddle.Tensor: Transformed value weighted + by the attention score, (#batch, time1, d_model). + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = paddle.softmax( + scores, axis=-1).masked_fill(mask, + 0.0) # (batch, head, time1, time2) + else: + attn = paddle.softmax( + scores, axis=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose([0, 2, 1, 3]).contiguous().view( + n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + scores = paddle.matmul(q, + k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding.""" + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
+ """ + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #torch.nn.init.xavier_uniform_(self.pos_bias_u) + #torch.nn.init.xavier_uniform_(self.pos_bias_v) + pos_bias_u = self.create_parameter( + [self.h, self.d_k], default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_u', pos_bias_u) + pos_bias_v = self.create_parameter( + (self.h, self.d_k), default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_v', pos_bias_v) + + def rel_shift(self, x, zero_triu: bool=False): + """Compute relative positinal encoding. + Args: + x (paddle.Tensor): Input tensor (batch, head, time1, time1). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + paddle.Tensor: Output tensor. (batch, head, time1, time1) + """ + zero_pad = paddle.zeros( + (x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) + x_padded = paddle.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] + + if zero_triu: + ones = paddle.ones((x.size(2), x.size(3))) + x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + pos_emb: paddle.Tensor, + mask: Optional[paddle.Tensor]): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time1, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + paddle.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. 
+ # matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/text_processing/speechtask/punctuation_restoration/modules/crf.py b/text_processing/speechtask/punctuation_restoration/modules/crf.py new file mode 100644 index 00000000..0a53ae6f --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/crf.py @@ -0,0 +1,366 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +from paddle import nn + +__all__ = ['CRF'] + + +class CRF(nn.Layer): + """ + Linear-chain Conditional Random Field (CRF). + + Args: + nb_labels (int): number of labels in your tagset, including special symbols. + bos_tag_id (int): integer representing the beginning of sentence symbol in + your tagset. + eos_tag_id (int): integer representing the end of sentence symbol in your tagset. + pad_tag_id (int, optional): integer representing the pad symbol in your tagset. + If None, the model will treat the PAD as a normal tag. Otherwise, the model + will apply constraints for PAD transitions. + batch_first (bool): Whether the first dimension represents the batch dimension. + """ + + def __init__(self, + nb_labels: int, + bos_tag_id: int, + eos_tag_id: int, + pad_tag_id: int=None, + batch_first: bool=True): + super().__init__() + + self.nb_labels = nb_labels + self.BOS_TAG_ID = bos_tag_id + self.EOS_TAG_ID = eos_tag_id + self.PAD_TAG_ID = pad_tag_id + self.batch_first = batch_first + + # initialize transitions from a random uniform distribution between -0.1 and 0.1 + self.transitions = self.create_parameter( + [self.nb_labels, self.nb_labels], + default_initializer=nn.initializer.Uniform(-0.1, 0.1)) + self.init_weights() + + def init_weights(self): + # enforce contraints (rows=from, columns=to) with a big negative number + # so exp(-10000) will tend to zero + + # no transitions allowed to the beginning of sentence + self.transitions[:, self.BOS_TAG_ID] = -10000.0 + # no transition alloed from the end of sentence + self.transitions[self.EOS_TAG_ID, :] = -10000.0 + + if self.PAD_TAG_ID is not None: + # no transitions from padding + self.transitions[self.PAD_TAG_ID, :] = -10000.0 + # no transitions to padding + self.transitions[:, self.PAD_TAG_ID] = -10000.0 + # except if the end of sentence is reached + # or we are already in a pad position + self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0 + self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0 + + def forward(self, + emissions: paddle.Tensor, + tags: paddle.Tensor, + mask: paddle.Tensor=None) -> paddle.Tensor: + """Compute the negative log-likelihood. See `log_likelihood` method.""" + nll = -self.log_likelihood(emissions, tags, mask=mask) + return nll + + def log_likelihood(self, emissions, tags, mask=None): + """Compute the probability of a sequence of tags given a sequence of + emissions scores. 
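+
+        Concretely, this returns the sum over the batch of
+        log p(y|x) = score(x, y) - log Z(x), where score(x, y) adds the
+        emission and transition terms along the tag path and log Z(x) is the
+        log-partition term computed by the forward algorithm
+        (see `_compute_log_partition`).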
+ + Args: + emissions (paddle.Tensor): Sequence of emissions for each label. + Shape of (batch_size, seq_len, nb_labels) if batch_first is True, + (seq_len, batch_size, nb_labels) otherwise. + tags (paddle.LongTensor): Sequence of labels. + Shape of (batch_size, seq_len) if batch_first is True, + (seq_len, batch_size) otherwise. + mask (paddle.FloatTensor, optional): Tensor representing valid positions. + If None, all positions are considered valid. + Shape of (batch_size, seq_len) if batch_first is True, + (seq_len, batch_size) otherwise. + + Returns: + paddle.Tensor: sum of the log-likelihoods for each sequence in the batch. + Shape of () + """ + # fix tensors order by setting batch as the first dimension + if not self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + + if mask is None: + mask = paddle.ones(emissions.shape[:2], dtype=paddle.float) + + scores = self._compute_scores(emissions, tags, mask=mask) + partition = self._compute_log_partition(emissions, mask=mask) + return paddle.sum(scores - partition) + + def decode(self, emissions, mask=None): + """Find the most probable sequence of labels given the emissions using + the Viterbi algorithm. + + Args: + emissions (paddle.Tensor): Sequence of emissions for each label. + Shape (batch_size, seq_len, nb_labels) if batch_first is True, + (seq_len, batch_size, nb_labels) otherwise. + mask (paddle.FloatTensor, optional): Tensor representing valid positions. + If None, all positions are considered valid. + Shape (batch_size, seq_len) if batch_first is True, + (seq_len, batch_size) otherwise. + + Returns: + paddle.Tensor: the viterbi score for the for each batch. + Shape of (batch_size,) + list of lists: the best viterbi sequence of labels for each batch. [B, T] + """ + # fix tensors order by setting batch as the first dimension + if not self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + + if mask is None: + mask = paddle.ones(emissions.shape[:2], dtype=paddle.float) + + scores, sequences = self._viterbi_decode(emissions, mask) + return scores, sequences + + def _compute_scores(self, emissions, tags, mask): + """Compute the scores for a given batch of emissions with their tags. + + Args: + emissions (paddle.Tensor): (batch_size, seq_len, nb_labels) + tags (Paddle.LongTensor): (batch_size, seq_len) + mask (Paddle.FloatTensor): (batch_size, seq_len) + + Returns: + paddle.Tensor: Scores for each batch. + Shape of (batch_size,) + """ + batch_size, seq_length = tags.shape + scores = paddle.zeros([batch_size]) + + # save first and last tags to be used later + first_tags = tags[:, 0] + last_valid_idx = mask.int().sum(1) - 1 + + # TODO(Hui Zhang): not support fancy index. 
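+        # Workaround: build explicit (row, col) index pairs and use gather_nd;
+        # e.g. batch_idx=[0, 1] and last_valid_idx=[3, 2] give the pairs
+        # [[0, 3], [1, 2]], selecting one tag per sequence.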
+ # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze() + batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype) + gather_last_valid_idx = paddle.stack( + [batch_idx, last_valid_idx], axis=-1) + last_tags = tags.gather_nd(gather_last_valid_idx) + + # add the transition from BOS to the first tags for each batch + # t_scores = self.transitions[self.BOS_TAG_ID, first_tags] + t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags) + + # add the [unary] emission scores for the first tags for each batch + # for all batches, the first word, see the correspondent emissions + # for the first tags (which is a list of ids): + # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]] + # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze() + gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1) + e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx) + + # the scores for a word is just the sum of both scores + scores += e_scores + t_scores + + # now lets do this for each remaining word + for i in range(1, seq_length): + + # we could: iterate over batches, check if we reached a mask symbol + # and stop the iteration, but vecotrizing is faster due to gpu, + # so instead we perform an element-wise multiplication + is_valid = mask[:, i] + + previous_tags = tags[:, i - 1] + current_tags = tags[:, i] + + # calculate emission and transition scores as we did before + # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze() + gather_current_tags_idx = paddle.stack( + [batch_idx, current_tags], axis=-1) + e_scores = emissions[:, i].gather_nd(gather_current_tags_idx) + # t_scores = self.transitions[previous_tags, current_tags] + gather_transitions_idx = paddle.stack( + [previous_tags, current_tags], axis=-1) + t_scores = self.transitions.gather_nd(gather_transitions_idx) + + # apply the mask + e_scores = e_scores * is_valid + t_scores = t_scores * is_valid + + scores += e_scores + t_scores + + # add the transition from the end tag to the EOS tag for each batch + # scores += self.transitions[last_tags, self.EOS_TAG_ID] + scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID] + + return scores + + def _compute_log_partition(self, emissions, mask): + """Compute the partition function in log-space using the forward-algorithm. + + Args: + emissions (paddle.Tensor): (batch_size, seq_len, nb_labels) + mask (Paddle.FloatTensor): (batch_size, seq_len) + + Returns: + paddle.Tensor: the partition scores for each batch. 
+ Shape of (batch_size,) + """ + batch_size, seq_length, nb_labels = emissions.shape + + # in the first iteration, BOS will have all the scores + alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze( + 0) + emissions[:, 0] + + for i in range(1, seq_length): + # (bs, nb_labels) -> (bs, 1, nb_labels) + e_scores = emissions[:, i].unsqueeze(1) + + # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels) + t_scores = self.transitions.unsqueeze(0) + + # (bs, nb_labels) -> (bs, nb_labels, 1) + a_scores = alphas.unsqueeze(2) + + scores = e_scores + t_scores + a_scores + new_alphas = paddle.logsumexp(scores, axis=1) + + # set alphas if the mask is valid, otherwise keep the current values + is_valid = mask[:, i].unsqueeze(-1) + alphas = is_valid * new_alphas + (1 - is_valid) * alphas + + # add the scores for the final transition + last_transition = self.transitions[:, self.EOS_TAG_ID] + end_scores = alphas + last_transition.unsqueeze(0) + + # return a *log* of sums of exps + return paddle.logsumexp(end_scores, axis=1) + + def _viterbi_decode(self, emissions, mask): + """Compute the viterbi algorithm to find the most probable sequence of labels + given a sequence of emissions. + + Args: + emissions (paddle.Tensor): (batch_size, seq_len, nb_labels) + mask (Paddle.FloatTensor): (batch_size, seq_len) + + Returns: + paddle.Tensor: the viterbi score for the for each batch. + Shape of (batch_size,) + list of lists of ints: the best viterbi sequence of labels for each batch + """ + batch_size, seq_length, nb_labels = emissions.shape + + # in the first iteration, BOS will have all the scores and then, the max + alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze( + 0) + emissions[:, 0] + + backpointers = [] + + for i in range(1, seq_length): + # (bs, nb_labels) -> (bs, 1, nb_labels) + e_scores = emissions[:, i].unsqueeze(1) + + # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels) + t_scores = self.transitions.unsqueeze(0) + + # (bs, nb_labels) -> (bs, nb_labels, 1) + a_scores = alphas.unsqueeze(2) + + # combine current scores with previous alphas + scores = e_scores + t_scores + a_scores + + # so far is exactly like the forward algorithm, + # but now, instead of calculating the logsumexp, + # we will find the highest score and the tag associated with it + # max_scores, max_score_tags = paddle.max(scores, axis=1) + max_scores = paddle.max(scores, axis=1) + max_score_tags = paddle.argmax(scores, axis=1) + + # set alphas if the mask is valid, otherwise keep the current values + is_valid = mask[:, i].unsqueeze(-1) + alphas = is_valid * max_scores + (1 - is_valid) * alphas + + # add the max_score_tags for our list of backpointers + # max_scores has shape (batch_size, nb_labels) so we transpose it to + # be compatible with our previous loopy version of viterbi + backpointers.append(max_score_tags.t()) + + # add the scores for the final transition + last_transition = self.transitions[:, self.EOS_TAG_ID] + end_scores = alphas + last_transition.unsqueeze(0) + + # get the final most probable score and the final most probable tag + # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1) + max_final_scores = paddle.max(end_scores, axis=1) + max_final_tags = paddle.argmax(end_scores, axis=1) + + # find the best sequence of labels for each sample in the batch + best_sequences = [] + emission_lengths = mask.int().sum(axis=1) + for i in range(batch_size): + + # recover the original sentence length for the i-th sample in the batch + sample_length = emission_lengths[i].item() + + # recover the max tag for 
the last timestep + sample_final_tag = max_final_tags[i].item() + + # limit the backpointers until the last but one + # since the last corresponds to the sample_final_tag + sample_backpointers = backpointers[:sample_length - 1] + + # follow the backpointers to build the sequence of labels + sample_path = self._find_best_path(i, sample_final_tag, + sample_backpointers) + + # add this path to the list of best sequences + best_sequences.append(sample_path) + + return max_final_scores, best_sequences + + def _find_best_path(self, sample_id, best_tag, backpointers): + """Auxiliary function to find the best path sequence for a specific sample. + + Args: + sample_id (int): sample index in the range [0, batch_size) + best_tag (int): tag which maximizes the final score + backpointers (list of lists of tensors): list of pointers with + shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i + represents the length of the ith sample in the batch + + Returns: + list of ints: a list of tag indexes representing the bast path + """ + # add the final best_tag to our best path + best_path = [best_tag] + + # traverse the backpointers in backwards + for backpointers_t in reversed(backpointers): + + # recover the best_tag at this timestep + best_tag = backpointers_t[best_tag][sample_id].item() + + # append to the beginning of the list so we don't need to reverse it later + best_path.insert(0, best_tag) + + return best_path diff --git a/text_processing/speechtask/punctuation_restoration/training/__init__.py b/text_processing/speechtask/punctuation_restoration/training/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/training/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/text_processing/speechtask/punctuation_restoration/training/loss.py b/text_processing/speechtask/punctuation_restoration/training/loss.py new file mode 100644 index 00000000..356dfcab --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/training/loss.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
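+# Focal loss (Lin et al., 2017): FL(p_t) = -(1 - p_t)^gamma * log(p_t).
+# The (1 - p_t)^gamma factor down-weights already well-classified examples;
+# gamma = 0 recovers plain cross entropy.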
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class FocalLossHX(nn.Layer):
+    def __init__(self, gamma=0, size_average=True):
+        super(FocalLossHX, self).__init__()
+        self.gamma = gamma
+        self.size_average = size_average
+
+    def forward(self, input, target):
+        if input.ndim > 2:
+            # N,C,H,W => N,C,H*W
+            input = paddle.reshape(
+                input, shape=[input.shape[0], input.shape[1], -1])
+            # N,C,H*W => N,H*W,C
+            input = paddle.transpose(input, perm=[0, 2, 1])
+            # N,H*W,C => N*H*W,C
+            input = paddle.reshape(input, shape=[-1, input.shape[2]])
+        target = paddle.reshape(target, shape=[-1])
+
+        logpt = F.log_softmax(input)
+
+        # pick the log-probability of the true class from each row
+        all_rows = paddle.arange(len(input))
+        log_pt = logpt.numpy()[all_rows.numpy(), target.numpy()]
+
+        pt = paddle.to_tensor(log_pt, dtype='float64').exp()
+        ce = F.cross_entropy(input, target, reduction='none')
+
+        # down-weight easy examples by (1 - pt)^gamma
+        loss = (1 - pt)**self.gamma * ce
+        if self.size_average:
+            return loss.mean()
+        else:
+            return loss.sum()
+
+
+class FocalLoss(nn.Layer):
+    """
+    Focal Loss.
+    Code referenced from:
+    https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
+    Args:
+        gamma (float): the coefficient of Focal Loss.
+    """
+
+    def __init__(self, gamma=2.0):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+
+    def forward(self, logit, label):
+        # gather the log-probability of the true class for each sample
+        label = paddle.reshape(label, [-1, 1])
+        range_ = paddle.arange(0, label.shape[0])
+        range_ = paddle.unsqueeze(range_, axis=-1)
+        label = paddle.cast(label, dtype='int64')
+        label = paddle.concat([range_, label], axis=-1)
+        logpt = F.log_softmax(logit)
+        logpt = paddle.gather_nd(logpt, label)
+
+        # pt is detached so the modulating factor is not back-propagated
+        pt = paddle.exp(logpt.detach())
+        loss = -1 * (1 - pt)**self.gamma * logpt
+        loss = paddle.mean(loss)
+        return loss
diff --git a/text_processing/speechtask/punctuation_restoration/training/trainer.py b/text_processing/speechtask/punctuation_restoration/training/trainer.py
new file mode 100644
index 00000000..2dce88a3
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/training/trainer.py
@@ -0,0 +1,651 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
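+# Note: a minimal sketch of the config this trainer expects, using only keys
+# that are read in this file (values are illustrative; see the YAML files
+# under examples/punctuation_restoration for the real configs):
+#
+#   model_type: BertLinear          # key into DefinedClassifier
+#   model_params: {...}
+#   loss_type: ce                   # optional; defaults to CrossEntropyLoss
+#   data:
+#     dataset_type: Bert            # key into DefinedDataset
+#     train_path: data/train
+#     dev_path: data/dev
+#     batch_size: 32
+#     num_workers: 2
+#     data_params: {...}
+#   training:
+#     n_epoch: 20
+#     lr: 1.0e-5
+#     lr_decay: 1.0
+#     weight_decay: 1.0e-6
+#     log_interval: 10
+#     log_path: log/train.log
+#   checkpoint: {kbest_n: 5, latest_n: 2}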
+import logging +import time +from collections import defaultdict +from pathlib import Path + +import numpy as np +import paddle +import paddle.nn as nn +import pandas as pd +from paddle import distributed as dist +from paddle.io import DataLoader +from sklearn.metrics import classification_report +from sklearn.metrics import f1_score +from sklearn.metrics import precision_recall_fscore_support +from speechtask.punctuation_restoration.io.dataset import PuncDataset +from speechtask.punctuation_restoration.io.dataset import PuncDatasetFromBertTokenizer +from speechtask.punctuation_restoration.model.BertBLSTM import BertBLSTMPunc +from speechtask.punctuation_restoration.model.BertLinear import BertLinearPunc +from speechtask.punctuation_restoration.model.blstm import BiLSTM +from speechtask.punctuation_restoration.model.lstm import RnnLm +from speechtask.punctuation_restoration.utils import layer_tools +from speechtask.punctuation_restoration.utils import mp_tools +from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint +from tensorboardX import SummaryWriter + +__all__ = ["Trainer", "Tester"] + +DefinedClassifier = { + "lstm": RnnLm, + "blstm": BiLSTM, + "BertLinear": BertLinearPunc, + "BertBLSTM": BertBLSTMPunc +} + +DefinedLoss = { + "ce": nn.CrossEntropyLoss, +} + +DefinedDataset = { + 'PuncCh': PuncDataset, + 'Bert': PuncDatasetFromBertTokenizer, +} + + +class Trainer(): + """ + An experiment template in order to structure the training code and take + care of saving, loading, logging, visualization stuffs. It"s intended to + be flexible and simple. + + So it only handles output directory (create directory for the output, + create a checkpoint directory, dump the config in use and create + visualizer and logger) in a standard way without enforcing any + input-output protocols to the model and dataloader. It leaves the main + part for the user to implement their own (setup the model, criterion, + optimizer, define a training step, define a validation function and + customize all the text and visual logs). + It does not save too much boilerplate code. The users still have to write + the forward/backward/update mannually, but they are free to add + non-standard behaviors if needed. + We have some conventions to follow. + 1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and + ``valid_loader``, ``config`` and ``args`` attributes. + 2. The config should have a ``training`` field, which has + ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is + used as the trigger to invoke validation, checkpointing and stop of the + experiment. + 3. There are four methods, namely ``train_batch``, ``valid``, + ``setup_model`` and ``setup_dataloader`` that should be implemented. + Feel free to add/overwrite other methods and standalone functions if you + need. + + Parameters + ---------- + config: yacs.config.CfgNode + The configuration used for the experiment. + + args: argparse.Namespace + The parsed command line arguments. 
+ Examples + -------- + >>> def main_sp(config, args): + >>> exp = Trainer(config, args) + >>> exp.setup() + >>> exp.run() + >>> + >>> config = get_cfg_defaults() + >>> parser = default_argument_parser() + >>> args = parser.parse_args() + >>> if args.config: + >>> config.merge_from_file(args.config) + >>> if args.opts: + >>> config.merge_from_list(args.opts) + >>> config.freeze() + >>> + >>> if args.nprocs > 1 and args.device == "gpu": + >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + >>> else: + >>> main_sp(config, args) + """ + + def __init__(self, config, args): + self.config = config + self.args = args + self.optimizer = None + self.visualizer = None + self.output_dir = None + self.checkpoint_dir = None + self.iteration = 0 + self.epoch = 0 + + def setup(self): + """Setup the experiment. + """ + self.setup_logger() + paddle.set_device(self.args.device) + if self.parallel: + self.init_parallel() + + self.setup_output_dir() + self.dump_config() + self.setup_visualizer() + self.setup_checkpointer() + + self.setup_model() + + self.setup_dataloader() + + self.iteration = 0 + self.epoch = 0 + + @property + def parallel(self): + """A flag indicating whether the experiment should run with + multiprocessing. + """ + return self.args.device == "gpu" and self.args.nprocs > 1 + + def init_parallel(self): + """Init environment for multiprocess training. + """ + dist.init_parallel_env() + + @mp_tools.rank_zero_only + def save(self, tag=None, infos: dict=None): + """Save checkpoint (model parameters and optimizer states). + + Args: + tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None. + infos (dict, optional): meta data to save. Defaults to None. + """ + + infos = infos if infos else dict() + infos.update({ + "step": self.iteration, + "epoch": self.epoch, + "lr": self.optimizer.get_lr() + }) + self.checkpointer.add_checkpoint(self.checkpoint_dir, self.iteration + if tag is None else tag, self.model, + self.optimizer, infos) + + def resume_or_scratch(self): + """Resume from latest checkpoint at checkpoints in the output + directory or load a specified checkpoint. + + If ``args.checkpoint_path`` is not None, load the checkpoint, else + resume training. + """ + scratch = None + infos = self.checkpointer.load_parameters( + self.model, + self.optimizer, + checkpoint_dir=self.checkpoint_dir, + checkpoint_path=self.args.checkpoint_path) + if infos: + # restore from ckpt + self.iteration = infos["step"] + self.epoch = infos["epoch"] + scratch = False + else: + self.iteration = 0 + self.epoch = 0 + scratch = True + + return scratch + + def new_epoch(self): + """Reset the train loader seed and increment `epoch`. + """ + self.epoch += 1 + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + + def train(self): + """The training process control by epoch.""" + from_scratch = self.resume_or_scratch() + + if from_scratch: + # save init model, i.e. 
0 epoch + self.save(tag="init") + + self.lr_scheduler.step(self.iteration) + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + + self.logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") + self.punc_list = [] + for i in range(len(self.train_loader.dataset.id2punc)): + self.punc_list.append(self.train_loader.dataset.id2punc[i]) + while self.epoch < self.config["training"]["n_epoch"]: + self.model.train() + self.total_label_train = [] + self.total_predict_train = [] + try: + data_start_time = time.time() + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + t = classification_report( + self.total_label_train, + self.total_predict_train, + target_names=self.punc_list) + self.logger.info(t) + except Exception as e: + self.logger.error(e) + raise e + + total_loss, F1_score = self.valid() + self.logger.info("Epoch {} Val info val_loss {}, F1_score {}". + format(self.epoch, total_loss, F1_score)) + if self.visualizer: + self.visualizer.add_scalars("epoch", { + "total_loss": total_loss, + "lr": self.lr_scheduler() + }, self.epoch) + + self.save( + tag=self.epoch, infos={"val_loss": total_loss, + "F1": F1_score}) + # step lr every epoch + self.lr_scheduler.step() + self.new_epoch() + + def run(self): + """The routine of the experiment after setup. This method is intended + to be used by the user. + """ + try: + self.train() + except KeyboardInterrupt: + self.save() + exit(-1) + finally: + self.destory() + self.logger.info("Training Done.") + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir + + def setup_checkpointer(self): + """Create a directory used to save checkpoints into. + + It is "checkpoints" inside the output directory. 
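+
+        The checkpointer keeps the ``kbest_n`` best checkpoints ranked by
+        val_loss and the ``latest_n`` most recent ones, as configured under
+        ``checkpoint`` in the config.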
+ """ + # checkpoint dir + self.checkpointer = Checkpoint(self.logger, + self.config["checkpoint"]["kbest_n"], + self.config["checkpoint"]["latest_n"]) + + checkpoint_dir = self.output_dir / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + + self.checkpoint_dir = checkpoint_dir + + def setup_logger(self): + LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s" + format_str = logging.Formatter( + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + ) + logging.basicConfig( + filename=self.config["training"]["log_path"], + level=logging.INFO, + format=LOG_FORMAT) + self.logger = logging.getLogger(__name__) + # self.logger = logging.getLogger(self.config["training"]["log_path"].strip().split('/')[-1].split('.')[0]) + + self.logger.setLevel(logging.INFO) #设置日志级别 + sh = logging.StreamHandler() #往屏幕上输出 + sh.setFormatter(format_str) #设置屏幕上显示的格式 + self.logger.addHandler(sh) #把对象加到logger里 + + self.logger.info('info') + print("setup logger!!!") + + @mp_tools.rank_zero_only + def destory(self): + """Close visualizer to avoid hanging after training""" + # https://github.com/pytorch/fairseq/issues/2357 + if self.visualizer: + self.visualizer.close() + + @mp_tools.rank_zero_only + def setup_visualizer(self): + """Initialize a visualizer to log the experiment. + + The visual log is saved in the output directory. + + Notes + ------ + Only the main process has a visualizer with it. Use multiple + visualizers in multiprocess to write to a same log file may cause + unexpected behaviors. + """ + # visualizer + visualizer = SummaryWriter(logdir=str(self.output_dir)) + self.visualizer = visualizer + + @mp_tools.rank_zero_only + def dump_config(self): + """Save the configuration used for this experiment. + + It is saved in to ``config.yaml`` in the output directory at the + beginning of the experiment. 
+ """ + with open(self.output_dir / "config.yaml", "wt") as f: + print(self.config, file=f) + + def train_batch(self, batch_index, batch_data, msg): + start = time.time() + + input, label = batch_data + label = paddle.reshape(label, shape=[-1]) + y, logit = self.model(input) + pred = paddle.argmax(logit, axis=1) + self.total_label_train.extend(label.numpy().tolist()) + self.total_predict_train.extend(pred.numpy().tolist()) + # self.total_predict.append(logit.numpy().tolist()) + # print('--after model----') + # # print(label.shape) + # # print(pred.shape) + # # print('--!!!!!!!!!!!!!----') + # print("self.total_label") + # print(self.total_label) + # print("self.total_predict") + # print(self.total_predict) + loss = self.crit(y, label) + + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + self.optimizer.step() + self.optimizer.clear_grad() + iteration_time = time.time() - start + + losses_np = { + "train_loss": float(loss), + } + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config["data"]["batch_size"]) + msg += ", ".join("{}: {:>.6f}".format(k, v) + for k, v in losses_np.items()) + self.logger.info(msg) + # print(msg) + + if dist.get_rank() == 0 and self.visualizer: + for k, v in losses_np.items(): + self.visualizer.add_scalar("train/{}".format(k), v, + self.iteration) + self.iteration += 1 + + @paddle.no_grad() + def valid(self): + self.logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") + self.model.eval() + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + valid_total_label = [] + valid_total_predict = [] + for i, batch in enumerate(self.valid_loader): + input, label = batch + label = paddle.reshape(label, shape=[-1]) + y, logit = self.model(input) + pred = paddle.argmax(logit, axis=1) + valid_total_label.extend(label.numpy().tolist()) + valid_total_predict.extend(pred.numpy().tolist()) + loss = self.crit(y, label) + + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(loss) * num_utts + valid_losses["val_loss"].append(float(loss)) + + if (i + 1) % self.config["training"]["log_interval"] == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump["val_history_loss"] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ", ".join("{}: {:>.6f}".format(k, v) + for k, v in valid_dump.items()) + self.logger.info(msg) + # print(msg) + + self.logger.info("Rank {} Val info val_loss {}".format( + dist.get_rank(), total_loss / num_seen_utts)) + # print("Rank {} Val info val_loss {} acc: {}".format( + # dist.get_rank(), total_loss / num_seen_utts, acc)) + F1_score = f1_score( + valid_total_label, valid_total_predict, average="macro") + return total_loss / num_seen_utts, F1_score + + def setup_model(self): + config = self.config + + model = DefinedClassifier[self.config["model_type"]]( + **self.config["model_params"]) + self.crit = DefinedLoss[self.config["loss_type"]](**self.config[ + "loss"]) if "loss_type" in self.config else DefinedLoss["ce"]() + + if self.parallel: + model = paddle.DataParallel(model) + + self.logger.info(f"{model}") + layer_tools.print_params(model, self.logger.info) + + lr_scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=config["training"]["lr"], + 
gamma=config["training"]["lr_decay"], + verbose=True) + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=paddle.regularizer.L2Decay( + config["training"]["weight_decay"])) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.logger.info("Setup model/criterion/optimizer/lr_scheduler!") + + def setup_dataloader(self): + print("setup_dataloader!!!") + config = self.config["data"].copy() + + print(config["batch_size"]) + + train_dataset = DefinedDataset[config["dataset_type"]]( + train_path=config["train_path"], **config["data_params"]) + dev_dataset = DefinedDataset[config["dataset_type"]]( + train_path=config["dev_path"], **config["data_params"]) + + # train_dataset = config["dataset_type"](os.path.join(config["save_path"], "train"), + # os.path.join(config["save_path"], config["vocab_file"]), + # os.path.join(config["save_path"], config["punc_file"]), + # config["seq_len"]) + + # dev_dataset = PuncDataset(os.path.join(config["save_path"], "dev"), + # os.path.join(config["save_path"], config["vocab_file"]), + # os.path.join(config["save_path"], config["punc_file"]), + # config["seq_len"]) + + # if self.parallel: + # batch_sampler = SortagradDistributedBatchSampler( + # train_dataset, + # batch_size=config["batch_size"], + # num_replicas=None, + # rank=None, + # shuffle=True, + # drop_last=True, + # sortagrad=config["sortagrad"], + # shuffle_method=config["shuffle_method"]) + # else: + # batch_sampler = SortagradBatchSampler( + # train_dataset, + # shuffle=True, + # batch_size=config["batch_size"], + # drop_last=True, + # sortagrad=config["sortagrad"], + # shuffle_method=config["shuffle_method"]) + + self.train_loader = DataLoader( + train_dataset, + num_workers=config["num_workers"], + batch_size=config["batch_size"]) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config["batch_size"], + shuffle=False, + drop_last=False, + num_workers=config["num_workers"]) + self.logger.info("Setup train/valid Dataloader!") + + +class Tester(Trainer): + def __init__(self, config, args): + super().__init__(config, args) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + self.logger.info( + f"Test Total Examples: {len(self.test_loader.dataset)}") + self.punc_list = [] + for i in range(len(self.test_loader.dataset.id2punc)): + self.punc_list.append(self.test_loader.dataset.id2punc[i]) + self.model.eval() + test_total_label = [] + test_total_predict = [] + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + input, label = batch + label = paddle.reshape(label, shape=[-1]) + y, logit = self.model(input) + pred = paddle.argmax(logit, axis=1) + test_total_label.extend(label.numpy().tolist()) + test_total_predict.extend(pred.numpy().tolist()) + # print(type(logit)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + self.logger.info(msg) + # print(msg) + t = classification_report( + test_total_label, test_total_predict, target_names=self.punc_list) + print(t) + t2 = self.evaluation(test_total_label, test_total_predict) + print(t2) + + def evaluation(self, y_pred, y_test): + precision, recall, f1, _ = precision_recall_fscore_support( + y_test, y_pred, average=None, labels=[1, 2, 3]) + overall = precision_recall_fscore_support( + y_test, y_pred, average='macro', labels=[1, 2, 3]) + result = pd.DataFrame( + np.array([precision, recall, f1]), + columns=list(['O', 'COMMA', 
+                                     'PERIOD', 'QUESTION'])[1:],
+            index=['Precision', 'Recall', 'F1'])
+        result['OVERALL'] = overall[:3]
+        return result
+
+    def run_test(self):
+        self.resume_or_scratch()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            exit(-1)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device(self.args.device)
+        self.setup_logger()
+        self.setup_output_dir()
+        self.setup_checkpointer()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    def setup_model(self):
+        config = self.config
+        model = DefinedClassifier[self.config["model_type"]](
+            **self.config["model_params"])
+
+        self.model = model
+        self.logger.info("Setup model!")
+
+    def setup_dataloader(self):
+        config = self.config["data"].copy()
+
+        test_dataset = DefinedDataset[config["dataset_type"]](
+            train_path=config["test_path"], **config["data_params"])
+
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config["batch_size"],
+            shuffle=False,
+            drop_last=False)
+        self.logger.info("Setup test Dataloader!")
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+        else:
+            output_dir = Path(
+                self.args.checkpoint_path).expanduser().parent.parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    def setup_logger(self):
+        LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+        format_str = logging.Formatter(LOG_FORMAT)
+        logging.basicConfig(
+            filename=self.config["testing"]["log_path"],
+            level=logging.INFO,
+            format=LOG_FORMAT)
+        self.logger = logging.getLogger(__name__)
+
+        self.logger.setLevel(logging.INFO)  # set the log level
+        sh = logging.StreamHandler()  # also log to the console
+        sh.setFormatter(format_str)  # set the console output format
+        self.logger.addHandler(sh)  # attach the handler to the logger
+
+        self.logger.info("Setup test logger!")
diff --git a/text_processing/speechtask/punctuation_restoration/utils/__init__.py b/text_processing/speechtask/punctuation_restoration/utils/__init__.py
new file mode 100644
index 00000000..185a92b8
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py b/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py
new file mode 100644
index 00000000..1ad4b5b3
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import glob
+import json
+import os
+import re
+from pathlib import Path
+from typing import Mapping
+from typing import Text
+from typing import Union
+
+import paddle
+from paddle import distributed as dist
+from paddle.optimizer import Optimizer
+from speechtask.punctuation_restoration.utils import mp_tools
+
+__all__ = ["Checkpoint"]
+
+
+class Checkpoint():
+    def __init__(self,
+                 logger,
+                 kbest_n: int=5,
+                 latest_n: int=1,
+                 metric_type='val_loss'):
+        self.best_records: Mapping[Path, float] = {}
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
+        self.logger = logger
+        self.metric_type = metric_type
+
+    def add_checkpoint(self,
+                       checkpoint_dir,
+                       tag_or_iteration: Union[int, Text],
+                       model: paddle.nn.Layer,
+                       optimizer: Optimizer=None,
+                       infos: dict=None):
+        """Save a checkpoint, keeping the k-best and the latest records.
+        Args:
+            checkpoint_dir (str): the directory where the checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration (step or epoch) number or a tag.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+            infos (dict, optional): any info you want to save.
+        """
+        metric_type = self.metric_type
+        if infos is None or metric_type not in infos:
+            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                                  optimizer, infos)
+            return
+
+        # save into the k-best records
+        if self._should_save_best(infos[metric_type]):
+            self._save_best_checkpoint_and_update(
+                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
+                optimizer, infos)
+        # save into the latest records
+        self._save_latest_checkpoint_and_update(
+            checkpoint_dir, tag_or_iteration, model, optimizer, infos)
+
+        if isinstance(tag_or_iteration, int):
+            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
+
+    def load_parameters(self,
+                        model,
+                        optimizer=None,
+                        checkpoint_dir=None,
+                        checkpoint_path=None,
+                        record_file="checkpoint_latest"):
+        """Load a model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters into.
+            optimizer (Optimizer, optional): optimizer to load states into if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoints are saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored at checkpoint_path (a path prefix) and ignore 'checkpoint_dir'.
+                Defaults to None.
+            record_file (str): "checkpoint_latest" or "checkpoint_best".
+        Returns:
+            configs (dict): epoch or step, lr and other saved meta info.
+ """ + configs = {} + + if checkpoint_path is not None: + pass + elif checkpoint_dir is not None and record_file is not None: + # load checkpint from record file + checkpoint_record = os.path.join(checkpoint_dir, record_file) + iteration = self._load_checkpoint_idx(checkpoint_record) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!" + ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + self.logger.info( + "Rank {}: loaded model from {}".format(rank, params_path)) + + optimizer_path = checkpoint_path + ".pdopt" + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + self.logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. 
+ """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): + # remove the worst + if self._best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if (worst_record_path not in self.latest_records): + self.logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) + self._del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def _save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): + # remove the old + if self._latest_full(): + to_del_fn = self.latest_records.pop(0) + if (to_del_fn not in self.best_records.keys()): + self.logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) + self._del_checkpoint(checkpoint_dir, to_del_fn) + self.latest_records.append(tag_or_iteration) + + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): + os.remove(filename) + self.logger.info("delete file: {}".format(filename)) + + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_path (str): the saved path of checkpoint. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. 
+    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
+        """Save the iteration numbers of the kept models to the checkpoint record files.
+        Args:
+            checkpoint_dir (str): the directory where checkpoints are saved.
+            iteration (int): the latest iteration number.
+        Returns:
+            None
+        """
+        checkpoint_record_latest = os.path.join(checkpoint_dir,
+                                                "checkpoint_latest")
+        checkpoint_record_best = os.path.join(checkpoint_dir,
+                                              "checkpoint_best")
+
+        with open(checkpoint_record_best, "w") as handle:
+            for i in self.best_records.keys():
+                handle.write("model_checkpoint_path:{}\n".format(i))
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
+                handle.write("model_checkpoint_path:{}\n".format(i))
+
+    @mp_tools.rank_zero_only
+    def _save_parameters(self,
+                         checkpoint_dir: str,
+                         tag_or_iteration: Union[int, str],
+                         model: paddle.nn.Layer,
+                         optimizer: Optimizer=None,
+                         infos: dict=None):
+        """Checkpoint the latest trained model parameters.
+        Args:
+            checkpoint_dir (str): the directory where the checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration (step or epoch) number.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+                Defaults to None.
+            infos (dict, optional): any info you want to save.
+        Returns:
+            None
+        """
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+
+        model_dict = model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        self.logger.info("Saved model to {}".format(params_path))
+
+        if optimizer:
+            opt_dict = optimizer.state_dict()
+            optimizer_path = checkpoint_path + ".pdopt"
+            paddle.save(opt_dict, optimizer_path)
+            self.logger.info(
+                "Saved optimizer state to {}".format(optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        infos = {} if infos is None else infos
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)
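+
+
+# A minimal usage sketch (illustrative only, not wired into the training
+# pipeline): keep the 5 best checkpoints by validation loss plus the single
+# latest one. The tiny model, optimizer, logger and directory below are
+# placeholder assumptions for demonstration.
+if __name__ == "__main__":
+    import logging
+    demo_logger = logging.getLogger("checkpoint_demo")
+    demo_model = paddle.nn.Linear(4, 2)
+    demo_optimizer = paddle.optimizer.Adam(parameters=demo_model.parameters())
+    checkpointer = Checkpoint(
+        demo_logger, kbest_n=5, latest_n=1, metric_type='val_loss')
+    for epoch in range(3):
+        checkpointer.add_checkpoint("exp/demo_checkpoints", epoch, demo_model,
+                                    demo_optimizer,
+                                    infos={'val_loss': 1.0 / (epoch + 1)})
+    # later, e.g. before testing, restore the best checkpoint:
+    checkpointer.load_best_parameters(
+        demo_model, checkpoint_dir="exp/demo_checkpoints")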
diff --git a/text_processing/speechtask/punctuation_restoration/utils/default_parser.py b/text_processing/speechtask/punctuation_restoration/utils/default_parser.py
new file mode 100644
index 00000000..b83d989d
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/default_parser.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+
+def default_argument_parser():
+    r"""A simple yet general argument parser for experiments.
+
+    It provides the minimal set of command line arguments needed to start a
+    training script, and is intended to be shared by the training and testing
+    entry points.
+
+    ``--config`` and ``--opts`` overwrite the default configuration.
+
+    ``--output`` specifies the output path. Resuming training from existing
+    progress at the output directory is the intended default behavior.
+
+    ``--checkpoint_path`` specifies the checkpoint to load from.
+
+    ``--device`` and ``--nprocs`` specify how to run the training.
+
+    Returns
+    -------
+    argparse.ArgumentParser
+        the parser
+    """
+    parser = argparse.ArgumentParser()
+
+    # yapf: disable
+    # data and output
+    parser.add_argument("--config", metavar="FILE", help="path of the config file that overwrites the default config.")
+    parser.add_argument("--dump-config", metavar="FILE", help="dump the config to a yaml file.")
+    parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoints and logs.")
+
+    # load from a saved checkpoint
+    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
+
+    # save the jit model to
+    parser.add_argument("--export_path", type=str, help="path to save the jit model")
+
+    # save the recognition result to
+    parser.add_argument("--result_file", type=str, help="path to save the result")
+
+    # running
+    parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
+                        help="device type to use, cpu and gpu are supported.")
+    parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
+
+    # overwrite the extra config and the default config
+    parser.add_argument("--opts", type=str, default=[], nargs='+',
+                        help="options to overwrite --config file and the default config, passed in as KEY VALUE pairs")
+    # yapf: enable
+
+    return parser
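+
+
+# A minimal usage sketch (illustrative; the paths below are placeholders):
+if __name__ == "__main__":
+    demo_parser = default_argument_parser()
+    demo_args = demo_parser.parse_args(
+        ["--config", "conf/train.yaml", "--output", "exp/demo",
+         "--device", "gpu", "--nprocs", "1"])
+    print(demo_args.config, demo_args.output, demo_args.device,
+          demo_args.nprocs)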
diff --git a/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py b/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py
new file mode 100644
index 00000000..fb076c0c
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from paddle import nn
+
+__all__ = [
+    "summary", "gradient_norm", "freeze", "unfreeze", "print_grads",
+    "print_params"
+]
+
+
+def summary(layer: nn.Layer, print_func=print):
+    if print_func is None:
+        return
+    num_params = num_elements = 0
+    for name, param in layer.state_dict().items():
+        print_func(
+            "{} | {} | {}".format(name, param.shape, np.prod(param.shape)))
+        num_elements += np.prod(param.shape)
+        num_params += 1
+    num_elements = num_elements / 1024**2
+    print_func(
+        f"Total parameters: {num_params}, {num_elements:.2f}M elements.")
+
+
+def print_grads(model, print_func=print):
+    if print_func is None:
+        return
+    for n, p in model.named_parameters():
+        msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
+        print_func(msg)
+
+
+def print_params(model, print_func=print):
+    if print_func is None:
+        return
+    total = 0.0
+    num_params = 0
+    for n, p in model.named_parameters():
+        msg = f"{n} | {p.shape} | {np.prod(p.shape)} | {not p.stop_gradient}"
+        total += np.prod(p.shape)
+        num_params += 1
+        print_func(msg)
+    total = total / 1024**2
+    print_func(f"Total parameters: {num_params}, {total:.2f}M elements.")
+
+
+def gradient_norm(layer: nn.Layer):
+    grad_norm_dict = {}
+    for name, param in layer.state_dict().items():
+        if param.trainable:
+            grad = param.gradient()  # returns a numpy.ndarray
+            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
+    return grad_norm_dict
+
+
+def recursively_remove_weight_norm(layer: nn.Layer):
+    for sublayer in layer.sublayers():
+        try:
+            nn.utils.remove_weight_norm(sublayer)
+        except ValueError:
+            # this sublayer has no weight norm hook
+            pass
+
+
+def freeze(layer: nn.Layer):
+    for param in layer.parameters():
+        param.trainable = False
+
+
+def unfreeze(layer: nn.Layer):
+    for param in layer.parameters():
+        param.trainable = True
diff --git a/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py b/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py
new file mode 100644
index 00000000..d3e25aab
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+
+from paddle import distributed as dist
+
+__all__ = ["rank_zero_only"]
+
+
+def rank_zero_only(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if dist.get_rank() != 0:
+            return
+        return func(*args, **kwargs)
+
+    return wrapper
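+
+
+# A minimal usage sketch (illustrative): under paddle.distributed
+# data-parallel training, a decorated function runs only on the rank-0
+# process, so side effects such as checkpointing or logging happen exactly
+# once.
+if __name__ == "__main__":
+    @rank_zero_only
+    def report():
+        print("only rank 0 executes this")
+
+    report()  # in a single process the rank is 0, so this prints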
diff --git a/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py b/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py
new file mode 100644
index 00000000..7f143182
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+
+CHINESE_PUNCTUATION_MAPPING = {
+    'O': '',
+    ',': ",",
+    '。': '。',
+    '?': '?',
+}
+
+
+def process_one_file_chinese(raw_path, save_path):
+    """Split every character of a Chinese text file into space-delimited tokens."""
+    with open(raw_path, 'r', encoding='utf-8') as f, \
+            open(save_path, 'w', encoding='utf-8') as save_file:
+        for line in f:
+            # drop ASCII and full-width spaces
+            line = line.strip().replace(' ', '').replace('\u3000', '')
+            for char in line:
+                save_file.write(char + ' ')
+            save_file.write('\n')
+
+
+def process_chinese_pure_senetence(config):
+    """Preprocess plain Chinese sentences.
+    Requires raw_path, raw_train_file, raw_dev_file, raw_test_file,
+    punc_file and save_path in the config.
+    """
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_train_file"])), "train file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_dev_file"])), "dev file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_test_file"])), "test file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "punc_file"])), "punc file doesn't exist."
+
+    train_file = os.path.join(config["raw_path"], config["raw_train_file"])
+    dev_file = os.path.join(config["raw_path"], config["raw_dev_file"])
+    test_file = os.path.join(config["raw_path"], config["raw_test_file"])
+    if not os.path.exists(config["save_path"]):
+        os.makedirs(config["save_path"])
+
+    shutil.copy(
+        os.path.join(config["raw_path"], config["punc_file"]),
+        os.path.join(config["save_path"], config["punc_file"]))
+
+    process_one_file_chinese(train_file,
+                             os.path.join(config["save_path"], "train"))
+    process_one_file_chinese(dev_file,
+                             os.path.join(config["save_path"], "dev"))
+    process_one_file_chinese(test_file,
+                             os.path.join(config["save_path"], "test"))
+
+
+def process_one_chinese_pair(raw_path, save_path):
+    """Rewrite (word, punctuation-label) pairs as plain punctuated text."""
+    with open(raw_path, 'r', encoding='utf-8') as f, \
+            open(save_path, 'w', encoding='utf-8') as save_file:
+        for line in f:
+            if len(line.strip().split()) == 2:
+                word, punc = line.strip().split()
+                save_file.write(word + ' ' + CHINESE_PUNCTUATION_MAPPING[punc])
+                if punc == "。":
+                    save_file.write("\n")
+                else:
+                    save_file.write(" ")
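+
+
+# Worked example of the Chinese preprocessing above (illustrative): the line
+#     你好 世界
+# is rewritten by process_one_file_chinese as
+#     你 好 世 界
+# i.e. spaces are removed and every character becomes its own token.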
+
+
+def process_chinese_pair(config):
+    """Preprocess Chinese (word, punctuation-label) pair files.
+    Requires raw_path, raw_train_file, raw_dev_file, raw_test_file,
+    punc_file and save_path in the config.
+    """
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_train_file"])), "train file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_dev_file"])), "dev file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_test_file"])), "test file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "punc_file"])), "punc file doesn't exist."
+
+    train_file = os.path.join(config["raw_path"], config["raw_train_file"])
+    dev_file = os.path.join(config["raw_path"], config["raw_dev_file"])
+    test_file = os.path.join(config["raw_path"], config["raw_test_file"])
+    if not os.path.exists(config["save_path"]):
+        os.makedirs(config["save_path"])
+
+    process_one_chinese_pair(train_file,
+                             os.path.join(config["save_path"], "train"))
+    process_one_chinese_pair(dev_file,
+                             os.path.join(config["save_path"], "dev"))
+    process_one_chinese_pair(test_file,
+                             os.path.join(config["save_path"], "test"))
+
+    shutil.copy(
+        os.path.join(config["raw_path"], config["punc_file"]),
+        os.path.join(config["save_path"], config["punc_file"]))
+
+
+english_punc = [',', '.', '?']
+ignore_english_punc = ['"', '/']
+
+
+def process_one_file_english(raw_path, save_path):
+    """Tokenize an English text file, splitting punctuation into separate tokens."""
+    with open(raw_path, 'r', encoding='utf-8') as f, \
+            open(save_path, 'w', encoding='utf-8') as save_file:
+        for line in f:
+            for i in ignore_english_punc:
+                line = line.replace(i, '')
+            for i in english_punc:
+                line = line.replace(i, ' ' + i)
+            wordlist = line.strip().split(' ')
+            for word in wordlist:
+                save_file.write(word + ' ')
+            save_file.write('\n')
+
+
+def process_english_pure_senetence(config):
+    """Preprocess plain English sentences.
+    Requires raw_path, raw_train_file, raw_dev_file, raw_test_file,
+    punc_file and save_path in the config.
+    """
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_train_file"])), "train file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_dev_file"])), "dev file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "raw_test_file"])), "test file doesn't exist."
+    assert os.path.exists(
+        os.path.join(config["raw_path"], config[
+            "punc_file"])), "punc file doesn't exist."
+
+    train_file = os.path.join(config["raw_path"], config["raw_train_file"])
+    dev_file = os.path.join(config["raw_path"], config["raw_dev_file"])
+    test_file = os.path.join(config["raw_path"], config["raw_test_file"])
+    if not os.path.exists(config["save_path"]):
+        os.makedirs(config["save_path"])
+
+    shutil.copy(
+        os.path.join(config["raw_path"], config["punc_file"]),
+        os.path.join(config["save_path"], config["punc_file"]))
+
+    process_one_file_english(train_file,
+                             os.path.join(config["save_path"], "train"))
+    process_one_file_english(dev_file,
+                             os.path.join(config["save_path"], "dev"))
+    process_one_file_english(test_file,
+                             os.path.join(config["save_path"], "test"))
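+
+
+# Worked example of the English preprocessing above (illustrative): the line
+#     He said, "yes." Really?
+# is rewritten by process_one_file_english as
+#     He said , yes . Really ?
+# i.e. quotes and slashes are dropped and each remaining punctuation mark
+# becomes its own space-delimited token for the downstream labelling step.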
+"""Contains common utility functions.""" +import distutils.util +import math +import os +from typing import List + +__all__ = ['print_arguments', 'add_arguments', "log_add"] + + +def print_arguments(args, info=None): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + filename = "" + if info: + filename = info["__file__"] + filename = os.path.basename(filename) + print(f"----------- {filename} Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + print("%s: %s" % (arg, value)) + print("-----------------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) + + +def log_add(args: List[int]) -> float: + """Stable log add + + Args: + args (List[int]): log scores + + Returns: + float: sum of log scores + """ + if all(a == -float('inf') for a in args): + return -float('inf') + a_max = max(args) + lsp = math.log(sum(math.exp(a - a_max) for a in args)) + return a_max + lsp