pull/1740/head
Hui Zhang 3 years ago
parent caf7225892
commit c7d9b11529

@@ -12,6 +12,8 @@ exclude =
.git,
# python cache
__pycache__,
# third party
utils/compute-wer.py,
third_party/,
# Provide a comma-separated list of glob patterns to include for checks.
filename =

@@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
@@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
logger.info(f"we will use the transformer like model : {model_type}")
logger.info(
f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,

@@ -305,6 +305,7 @@ class ASRClientExecutor(BaseExecutor):
return res['asr_results']
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
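The cli_register / cli_client_register decorators in the hunks above are what expose these executor classes as CLI subcommands. A minimal sketch of that registration pattern, assuming a plain dict registry rather than the real paddlespeech internals:

commands = {}

def cli_register(name, description):
    def _decorator(cls):
        # 'paddlespeech.asr' becomes the shell command 'paddlespeech asr'
        commands[name] = (cls, description)
        return cls
    return _decorator

@cli_register(
    name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor:
    def __call__(self, argv):
        print('running asr with', argv)

# dispatch: look up the registered class and invoke an instance of it
cls, _ = commands['paddlespeech.asr']
cls()(['--input', 'zh.wav'])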

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add

@@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
import argparse
import copy
import warnings
from distutils.util import strtobool
import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import linalg
from scipy import sparse
from scipy.sparse.csgraph import connected_components

@@ -2,6 +2,7 @@
import argparse
from collections import Counter
def main(args):
counter = Counter()
with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:
@@ -20,21 +21,16 @@ def main(args):
fout.write(f"{word}\t{val}\n")
fout.flush()
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
parser.add_argument(
'--has_key',
default=True,
help='text path, with utt or not')
'--has_key', default=True, help='text path, with utt or not')
parser.add_argument(
'--text',
required=True,
help='text path. line: utt1 中国 人 or 中国 人')
'--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
parser.add_argument(
'--lexicon',
required=True,
help='lexicon path. line:中国 中 国')
'--lexicon', required=True, help='lexicon path. line:中国 中 国')
args = parser.parse_args()
print(args)
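The script above builds a character lexicon from a text corpus. A self-contained sketch of the transformation its help strings describe (the split-into-characters step is assumed from the output format '中国 中 国'):

from collections import Counter

def text_to_lexicon(lines, has_key=True):
    counter = Counter()
    for line in lines:
        words = line.split()
        if has_key:  # drop the leading utterance id, e.g. 'utt1'
            words = words[1:]
        counter.update(words)
    # one lexicon line per word: the word, then its characters space-joined
    return [f"{word}\t{' '.join(word)}" for word in counter]

print(text_to_lexicon(['utt1 中国 人']))  # ['中国\t中 国', '人\t人']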

@@ -1,13 +1,14 @@
#!/usr/bin/env python3
# modify from https://sites.google.com/site/homepageoffuyanwei/Home/remarksandexcellentdiscussion/page-2
class Word:
def __init__(self, text='', freq=0):
self.text = text
self.freq = freq
self.length = len(text)
class Chunk:
def __init__(self, w1, w2=None, w3=None):
self.words = []
@@ -44,8 +45,8 @@ class Chunk:
sum += word.freq
return sum
class ComplexCompare:
class ComplexCompare:
def takeHightest(self, chunks, comparator):
i = 1
for j in range(1, len(chunks)):
@@ -61,21 +62,25 @@ class ComplexCompare:
def mmFilter(self, chunks):
def comparator(a, b):
return a.totalWordLength() - b.totalWordLength()
return self.takeHightest(chunks, comparator)
def lawlFilter(self, chunks):
def comparator(a, b):
return a.averageWordLength() - b.averageWordLength()
return self.takeHightest(chunks, comparator)
def svmlFilter(self, chunks):
def comparator(a, b):
return b.standardDeviation() - a.standardDeviation()
return self.takeHightest(chunks, comparator)
def logFreqFilter(self, chunks):
def comparator(a, b):
return a.wordFrequency() - b.wordFrequency()
return self.takeHightest(chunks, comparator)
@@ -83,6 +88,7 @@ class ComplexCompare:
dictWord = {}
maxWordLength = 0
def loadDictChars(filepath):
global maxWordLength
fsock = open(filepath)
@@ -90,18 +96,22 @@ def loadDictChars(filepath):
freq, word = line.split()
word = word.strip()
dictWord[word] = (len(word), int(freq))
maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
maxWordLength = len(word) if maxWordLength < len(
word) else maxWordLength
fsock.close()
def loadDictWords(filepath):
global maxWordLength
fsock = open(filepath)
for line in fsock.readlines():
word = line.strip()
dictWord[word] = (len(word), 0)
maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
maxWordLength = len(word) if maxWordLength < len(
word) else maxWordLength
fsock.close()
# check whether the word is in the dictionary dictWord
def getDictWord(word):
result = dictWord.get(word)
@@ -109,14 +119,15 @@ def getDictWord(word):
return Word(word, result[1])
return None
# load the dictionaries
def run():
from os.path import join, dirname
loadDictChars(join(dirname(__file__), 'data', 'chars.dic'))
loadDictWords(join(dirname(__file__), 'data', 'words.dic'))
class Analysis:
class Analysis:
def __init__(self, text):
self.text = text
self.cacheSize = 3
@@ -134,11 +145,10 @@ class Analysis:
if not dictWord:
run()
def __iter__(self):
while True:
token = self.getNextToken()
if token == None:
if token is None:
raise StopIteration
yield token
@@ -375,6 +385,8 @@ if __name__=="__main__":
cuttest(u"好人使用了它就可以解决一些问题")
cuttest(u"是因为和国家")
cuttest(u"老年搜索还支持")
cuttest(u"干脆就把那部蒙人的闲法给废了拉倒RT @laoshipukong : 27日全国人大常委会第三次审议侵权责任法草案删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 ")
cuttest(
u"干脆就把那部蒙人的闲法给废了拉倒RT @laoshipukong : 27日全国人大常委会第三次审议侵权责任法草案删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 "
)
cuttest("2022年12月30日是星期几")
cuttest("二零二二年十二月三十日是星期几?")

@@ -26,9 +26,9 @@ import argparse
import os
import re
import subprocess
from distutils.util import strtobool
import numpy as np
from distutils.util import strtobool
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
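These lookbehind patterns scrape aggregate fields out of an md-eval.pl report. A small self-contained check, with an illustrative (not verbatim) report line:

import re

SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
line = 'SCORED SPEAKER TIME =139.68 secs'  # illustrative sample line
print(float(SCORED_SPEAKER_TIME.search(line).group()))  # 139.68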

@@ -1,15 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata
import codecs
import re
import sys
import unicodedata
remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
          '、', '。', '！', '，', '；', '？',
          '：', '「', '」', '︰', '『', '』', '《', '》']
puncts = [
    '!', ',', '?', '、', '。', '！', '，', '；', '？', '：', '「', '」', '︰', '『', '』',
    '《', '》'
]
def characterize(string):
res = []
@@ -43,10 +46,12 @@ def characterize(string) :
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0; T=len(x)
i = 0
T = len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
@@ -78,6 +83,7 @@ def normalize(sentence, ignore_words, cs, split=None):
new_sentence.append(x)
return new_sentence
class Calculator:
def __init__(self):
self.data = {}
@@ -87,6 +93,7 @@ class Calculator :
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
@@ -108,10 +115,22 @@ class Calculator :
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
@@ -141,7 +160,15 @@ class Calculator :
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
@@ -184,8 +211,11 @@ class Calculator :
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
print(
'this should not happen , i = {i} , j = {j} , error = {error}'.
format(i=i, j=j, error=self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
@@ -195,6 +225,7 @@ class Calculator :
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
@@ -205,12 +236,15 @@ class Calculator :
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
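# a quick sanity check of width() above (assuming it is in scope): ASCII
# characters count one display column, fullwidth CJK ('W' in "AFW") count two
assert width('abc') == 3
assert width('中国') == 4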
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
@@ -250,9 +284,15 @@ def default_cluster(word) :
return 'Other'
return unicode_names[0]
def usage():
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
print(
"compute-wer.py : compute word error rate (WER) and align recognition results and references."
)
print(
" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
)
if __name__ == '__main__':
if len(sys.argv) == 1:
@@ -370,7 +410,8 @@ if __name__ == '__main__':
array = line.strip().split()
if len(array) == 0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8'):
@@ -399,12 +440,14 @@ if __name__ == '__main__':
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('WER: %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
@@ -446,30 +489,37 @@ if __name__ == '__main__':
rec1 = rec2
if verbose:
print('===========================================================================')
print(
'==========================================================================='
)
print()
result = calculator.overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('Overall -> %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters:
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
result = calculator.cluster(
[k for k in default_clusters[cluster_id]])
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if len(cluster_file) > 0: # compute separated WERs for word clusters
cluster_id = ''
cluster = []
@@ -480,12 +530,14 @@ if __name__ == '__main__':
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
@@ -497,4 +549,6 @@ if __name__ == '__main__':
else:
cluster.append(token)
print()
print('===========================================================================')
print(
'==========================================================================='
)
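All of the WER prints above use the same arithmetic: errors over reference length. A worked sketch, assuming the Calculator class from this file is importable:

calculator = Calculator()
lab = ['the', 'cat', 'sat']  # reference, N = 3
rec = ['the', 'bat', 'sat', 'down']  # hypothesis: 1 sub, 1 ins, 0 del
result = calculator.calculate(lab, rec)
# WER = (ins + sub + del) * 100 / N = (1 + 1 + 0) * 100 / 3 = 66.67 %
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
print('WER: %4.2f %%' % wer)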

@@ -1,11 +1,21 @@
import os
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import jsonlines
def trans_hyp(origin_hyp,
trans_hyp = None,
trans_hyp_sclite = None):
def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
@@ -27,9 +37,8 @@ def trans_hyp(origin_hyp,
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
def trans_ref(origin_ref,
trans_ref = None,
trans_ref_sclite = None):
def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
@@ -53,28 +62,34 @@ def trans_ref(origin_ref,
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
parser = argparse.ArgumentParser(
prog='format hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp', type=str, default=None, help='origin hyp file')
parser.add_argument(
'--origin_hyp',
'--trans_hyp',
type=str,
default=None,
help='origin hyp file')
help='hyp file for calculating CER/WER')
parser.add_argument(
'--trans_hyp', type=str, default = None, help='hyp file for calculating CER/WER')
parser.add_argument(
'--trans_hyp_sclite', type=str, default = None, help='hyp file for calculating CER/WER by sclite')
'--trans_hyp_sclite',
type=str,
default=None,
help='hyp file for calculating CER/WER by sclite')
parser.add_argument(
'--origin_ref',
'--origin_ref', type=str, default=None, help='origin ref file')
parser.add_argument(
'--trans_ref',
type=str,
default=None,
help='origin ref file')
parser.add_argument(
'--trans_ref', type=str, default = None, help='ref file for calculating CER/WER')
help='ref file for calculating CER/WER')
parser.add_argument(
'--trans_ref_sclite', type=str, default = None, help='ref file for calculating CER/WER by sclite')
'--trans_ref_sclite',
type=str,
default=None,
help='ref file for calculating CER/WER by sclite')
parser_args = parser.parse_args()
if parser_args.origin_hyp is not None:
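The helpers above emit one line per utterance in the sclite transcript format: the text followed by '(utt.wav)'. An illustrative round of that formatting (the id and text are made up):

input_dict = {'utt1': '甚至出现交易几乎停滞的情况'}
for key in input_dict.keys():
    print(input_dict[key] + '(' + key + '.wav' + ')')
# -> 甚至出现交易几乎停滞的情况(utt1.wav)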

@@ -82,7 +82,10 @@ def main(args):
lexicon_table.add(word)
out_n += 1
print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")
print(
f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}"
)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
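The filter message above counts lexicon entries dropped because a unit is missing from the unit table. A minimal sketch of that filter, with illustrative names and data:

unit_table = {'中', '国', '人'}
lexicon = {'中国': ['中', '国'], '外企': ['外', '企']}
kept = {w: u for w, u in lexicon.items() if all(x in unit_table for x in u)}
in_n, out_n = len(lexicon), len(kept)
print(f'Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}')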
