PaddleSpeech/text_processing/speechtask/punctuation_restoration/io/dataset.py


# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import paddle
from paddle.io import Dataset
from paddlenlp.transformers import BertTokenizer
# from speechtask.punctuation_restoration.utils.punct_prepro import load_dataset
__all__ = ["PuncDataset", "PuncDatasetFromBertTokenizer"]
class PuncDataset(Dataset):
"""Representing a Dataset
superclass
----------
data.Dataset :
Dataset is a abstract class, representing the real data.
"""
    def __init__(self, train_path, vocab_path, punc_path, seq_len=100):
        # Check that the input files exist.
        print(train_path)
        print(vocab_path)
        assert os.path.exists(train_path), "train file does not exist"
        assert os.path.exists(vocab_path), "vocab file does not exist"
        assert os.path.exists(punc_path), "punctuation file does not exist"
self.seq_len = seq_len
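        # load_vocab assigns ids 0..n-1 to the extra entries and id n onward to
        # the file entries, so '<UNK>' -> 0 and '<END>' -> 1 in word2id, and
        # " " (meaning "no punctuation after this word") -> 0 in punc2id.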
self.word2id = self.load_vocab(
vocab_path, extra_word_list=['<UNK>', '<END>'])
self.id2word = {v: k for k, v in self.word2id.items()}
self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
self.id2punc = {k: v for (v, k) in self.punc2id.items()}
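        # Read the whole training file and flatten it into a single stream of
        # whitespace-separated tokens (words and punctuation marks interleaved).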
tmp_seqs = open(train_path, encoding='utf-8').readlines()
self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
# print(self.txt_seqs[:10])
# with open('./txt_seq', 'w', encoding='utf-8') as w:
# print(self.txt_seqs, file=w)
self.preprocess(self.txt_seqs)
print('---punc-')
print(self.punc2id)
    def __len__(self):
        """Return the number of fixed-length sequences built from the text file.
        """
        return self.in_len
    def __getitem__(self, index):
        """Return the tensor pair at the given index: the sequence of input
        token ids and its corresponding sequence of punctuation ids.

        Parameters
        ----------
        index : int
            Sample index.
        """
return self.input_data[index], self.label[index]
    def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'):
        n = len(extra_word_list)
        with open(vocab_path, encoding=encoding) as vf:
            vocab = {word.strip(): i + n for i, word in enumerate(vf)}
        for i, word in enumerate(extra_word_list):
            vocab[word] = i
        return vocab
    def preprocess(self, txt_seqs: list):
        """Convert the text into pairs of word ids and the ids of the
        punctuation marks to be predicted after each word.

        Parameters
        ----------
        txt_seqs : list
            Text tokens; in the source text every word is followed by a space,
            and every punctuation mark is followed by a space too.
        """
input_data = []
label = []
input_r = []
label_r = []
# txt_seqs is a list like: ['char', 'char', 'char', '**', 'char', ......]
count = 0
length = len(txt_seqs)
        for token in txt_seqs:
            count += 1
            if count == length:
                break
            # Skip punctuation tokens themselves; they are only used as labels.
            if token in self.punc2id:
                continue
            # The label for a word is the token that follows it: a punctuation
            # id if the next token is punctuation, otherwise " " (no punctuation).
            punc = txt_seqs[count]
            if punc not in self.punc2id:
                # print('punct {}'.format(count), self.punc2id[" "])
                label.append(self.punc2id[" "])
                input_data.append(
                    self.word2id.get(token, self.word2id["<UNK>"]))
                input_r.append(token)
                label_r.append(' ')
            else:
                # print('punct {}'.format(count), self.punc2id[punc])
                label.append(self.punc2id[punc])
                input_data.append(
                    self.word2id.get(token, self.word2id["<UNK>"]))
                input_r.append(token)
                label_r.append(punc)
        assert len(input_data) == len(label), \
            'error: length input_data != label'
        # Truncate so the data splits evenly into sequences of length seq_len.
        print(len(input_data))
self.in_len = len(input_data) // self.seq_len
len_tmp = self.in_len * self.seq_len
input_data = input_data[:len_tmp]
label = label[:len_tmp]
self.input_data = paddle.to_tensor(
np.array(input_data, dtype='int64').reshape(-1, self.seq_len))
self.label = paddle.to_tensor(
np.array(label, dtype='int64').reshape(-1, self.seq_len))
# unk_token='[UNK]'
# sep_token='[SEP]'
# pad_token='[PAD]'
# cls_token='[CLS]'
# mask_token='[MASK]'
class PuncDatasetFromBertTokenizer(Dataset):
"""Representing a Dataset
superclass
----------
data.Dataset :
Dataset is a abstract class, representing the real data.
"""
def __init__(self,
train_path,
is_eval,
pretrained_token,
punc_path,
seq_len=100):
        # Check that the input files exist.
print(train_path)
self.tokenizer = BertTokenizer.from_pretrained(
pretrained_token, do_lower_case=True)
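        # pad_token_id is the id of the tokenizer's '[PAD]' token; it is kept
        # here (presumably) so that callers can pad variable-length windows
        # when batching.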
self.paddingID = self.tokenizer.pad_token_id
        assert os.path.exists(train_path), "train file does not exist"
        assert os.path.exists(punc_path), "punctuation file does not exist"
self.seq_len = seq_len
self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
self.id2punc = {k: v for (v, k) in self.punc2id.items()}
tmp_seqs = open(train_path, encoding='utf-8').readlines()
self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
# print(self.txt_seqs[:10])
# with open('./txt_seq', 'w', encoding='utf-8') as w:
# print(self.txt_seqs, file=w)
if (is_eval):
self.preprocess(self.txt_seqs)
else:
self.preprocess_shift(self.txt_seqs)
print("data len: %d" % (len(self.input_data)))
print('---punc-')
print(self.punc2id)
    def __len__(self):
        """Return the number of sequences (or shifted windows) built from the
        text file.
        """
        return self.in_len
    def __getitem__(self, index):
        """Return the pair at the given index: the sequence of input token ids
        and its corresponding sequence of punctuation ids.

        Parameters
        ----------
        index : int
            Sample index.
        """
return self.input_data[index], self.label[index]
    def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'):
        n = len(extra_word_list)
        with open(vocab_path, encoding=encoding) as vf:
            vocab = {word.strip(): i + n for i, word in enumerate(vf)}
        for i, word in enumerate(extra_word_list):
            vocab[word] = i
        return vocab
    def preprocess(self, txt_seqs: list):
        """Convert the text into pairs of BERT subword ids and the ids of the
        punctuation marks to be predicted.

        Parameters
        ----------
        txt_seqs : list
            Text tokens; in the source text every word is followed by a space,
            and every punctuation mark is followed by a space too.
        """
input_data = []
label = []
# txt_seqs is a list like: ['char', 'char', 'char', '**', 'char', ......]
        for i in range(len(txt_seqs) - 1):
            word = txt_seqs[i]
            punc = txt_seqs[i + 1]
            # Skip punctuation tokens themselves; they are only used as labels.
            if word in self.punc2id:
                continue
            # Tokenize a single word and strip the [CLS]/[SEP] ids that the
            # tokenizer adds at the beginning and end.
            token = self.tokenizer(word)
            x = token["input_ids"][1:-1]
            input_data.extend(x)
            # All subwords except the last one get the "no punctuation" label.
            for _ in range(len(x) - 1):
                label.append(self.punc2id[" "])
            # The last subword is labelled with the following punctuation mark,
            # or " " if the next token is a regular word.
            if punc not in self.punc2id:
                label.append(self.punc2id[" "])
            else:
                label.append(self.punc2id[punc])
        assert len(input_data) == len(label), \
            'error: length input_data != label'
        # Truncate so the data splits evenly into sequences of length seq_len.
# print(len(input_data[0]))
# print(len(label))
self.in_len = len(input_data) // self.seq_len
len_tmp = self.in_len * self.seq_len
input_data = input_data[:len_tmp]
label = label[:len_tmp]
        self.input_data = paddle.to_tensor(
            np.array(input_data, dtype='int64').reshape(-1, self.seq_len))
        self.label = paddle.to_tensor(
            np.array(label, dtype='int64').reshape(-1, self.seq_len))
    def preprocess_shift(self, txt_seqs: list):
        """Convert the text into pairs of BERT subword ids and punctuation ids,
        then cut them into overlapping windows with random starts and lengths.

        Parameters
        ----------
        txt_seqs : list
            Text tokens; in the source text every word is followed by a space,
            and every punctuation mark is followed by a space too.
        """
input_data = []
label = []
# txt_seqs is a list like: ['char', 'char', 'char', '**', 'char', ......]
        for i in range(len(txt_seqs) - 1):
            word = txt_seqs[i]
            punc = txt_seqs[i + 1]
            # Skip punctuation tokens themselves; they are only used as labels.
            if word in self.punc2id:
                continue
            # Tokenize a single word and strip the [CLS]/[SEP] ids added by
            # the tokenizer.
            token = self.tokenizer(word)
            x = token["input_ids"][1:-1]
            input_data.extend(x)
            # All subwords except the last one get the "no punctuation" label.
            for _ in range(len(x) - 1):
                label.append(self.punc2id[" "])
            # The last subword is labelled with the following punctuation mark,
            # or " " if the next token is a regular word.
            if punc not in self.punc2id:
                label.append(self.punc2id[" "])
            else:
                label.append(self.punc2id[punc])
        assert len(input_data) == len(label), \
            'error: length input_data != label'
# print(len(input_data[0]))
# print(len(label))
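        # Build overlapping training windows: each window gets a random length
        # in [seq_len // 2, seq_len] and the start advances by a random stride
        # in [1, seq_len // 2], so the model sees shifted segments of the data.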
start = 0
processed_data = []
processed_label = []
while (start < len(input_data) - self.seq_len):
# end=start+self.seq_len
end = random.randint(start + self.seq_len // 2,
start + self.seq_len)
processed_data.append(input_data[start:end])
processed_label.append(label[start:end])
start = start + random.randint(1, self.seq_len // 2)
self.in_len = len(processed_data)
        # The windows have varying lengths, so they are kept as Python lists
        # here rather than stacked into a single tensor.
        self.input_data = processed_data
        self.label = processed_label
if __name__ == '__main__':
    # Smoke test with placeholder paths; replace them with real files.
    dataset = PuncDataset(
        train_path='data/train.txt',
        vocab_path='data/vocab.txt',
        punc_path='data/punc_vocab.txt')
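    # A minimal usage sketch (assumed consumer: paddle.io.DataLoader with its
    # default collate function). PuncDataset yields fixed-length
    # (input_ids, punct_ids) tensor pairs, so batches stack directly.
    from paddle.io import DataLoader

    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for batch_input, batch_label in loader:
        # With the defaults above, each full batch is [batch_size, seq_len].
        print(batch_input.shape, batch_label.shape)
        break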