commit c7d9b11529 (pull/1740/head)
Author: Hui Zhang
Parent: caf7225892

@@ -12,6 +12,8 @@ exclude =
     .git,
     # python cache
     __pycache__,
+    # third party
+    utils/compute-wer.py,
     third_party/,
 # Provide a comma-separate list of glob patterns to include for checks.
 filename =

@@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig

 __all__ = ['ASRExecutor']

+
 @cli_register(
     name='paddlespeech.asr', description='Speech to text infer command.')
 class ASRExecutor(BaseExecutor):
@@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor):
             self._outputs["result"] = result_transcripts[0]

         elif "conformer" in model_type or "transformer" in model_type:
-            logger.info(f"we will use the transformer like model : {model_type}")
+            logger.info(
+                f"we will use the transformer like model : {model_type}")
             try:
                 result_transcripts = self.model.decode(
                     audio,

@@ -305,6 +305,7 @@ class ASRClientExecutor(BaseExecutor):
         return res['asr_results']

+
 @cli_client_register(
     name='paddlespeech_client.cls', description='visit cls service')
 class CLSClientExecutor(BaseExecutor):

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
+
 import paddle
+
 from paddlespeech.cli.log import logger
 from paddlespeech.s2t.utils.utility import log_add

@@ -36,7 +36,7 @@ class ASRAudioHandler:
         x_len = len(samples)
         chunk_size = 85 * 16  #80ms, sample_rate = 16kHz
-        if x_len % chunk_size!= 0:
+        if x_len % chunk_size != 0:
             padding_len_x = chunk_size - x_len % chunk_size
         else:
             padding_len_x = 0
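
Note: a minimal sketch of the padding rule this hunk touches, assuming only that samples is a 1-D array; pad_to_chunks is an illustrative name, not part of the diff. chunk_size = 85 * 16 is 1360 samples, i.e. 85 ms at 16 kHz (the #80ms comment is the file's own):

    import numpy as np

    def pad_to_chunks(samples, chunk_size=85 * 16):
        # zero-pad the tail so len(samples) becomes a whole number of chunks
        x_len = len(samples)
        padding_len_x = (chunk_size - x_len % chunk_size) % chunk_size  # 0 when already aligned
        return np.pad(samples, (0, padding_len_x))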

@@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
 import argparse
 import copy
 import warnings
-from distutils.util import strtobool

 import numpy as np
 import scipy
 import sklearn
+from distutils.util import strtobool
 from scipy import linalg
 from scipy import sparse
 from scipy.sparse.csgraph import connected_components

@@ -2,6 +2,7 @@
 import argparse
 from collections import Counter

+
 def main(args):
     counter = Counter()
     with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:
@@ -20,21 +21,16 @@ def main(args):
             fout.write(f"{word}\t{val}\n")
             fout.flush()

+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
     parser.add_argument(
-        '--has_key',
-        default=True,
-        help='text path, with utt or not')
+        '--has_key', default=True, help='text path, with utt or not')
     parser.add_argument(
-        '--text',
-        required=True,
-        help='text path. line: utt1 中国 人 or 中国 人')
+        '--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
     parser.add_argument(
-        '--lexicon',
-        required=True,
-        help='lexicon path. line:中国 中 国')
+        '--lexicon', required=True, help='lexicon path. line:中国 中 国')
     args = parser.parse_args()
     print(args)

@@ -1,15 +1,16 @@
 #!/usr/bin/env python3
 # modify from https://sites.google.com/site/homepageoffuyanwei/Home/remarksandexcellentdiscussion/page-2

 class Word:
-    def __init__(self,text = '',freq = 0):
+    def __init__(self, text='', freq=0):
         self.text = text
         self.freq = freq
         self.length = len(text)

+
 class Chunk:
-    def __init__(self,w1,w2 = None,w3 = None):
+    def __init__(self, w1, w2=None, w3=None):
         self.words = []
         self.words.append(w1)
         if w2:
@@ -44,8 +45,8 @@ class Chunk:
             sum += word.freq
         return sum

-class ComplexCompare:

+class ComplexCompare:
     def takeHightest(self, chunks, comparator):
         i = 1
         for j in range(1, len(chunks)):
@@ -59,23 +60,27 @@ class ComplexCompare:
     #以下四个函数是mmseg算法的四种过滤原则核心算法
     def mmFilter(self, chunks):
-        def comparator(a,b):
+        def comparator(a, b):
             return a.totalWordLength() - b.totalWordLength()

         return self.takeHightest(chunks, comparator)

-    def lawlFilter(self,chunks):
-        def comparator(a,b):
+    def lawlFilter(self, chunks):
+        def comparator(a, b):
             return a.averageWordLength() - b.averageWordLength()
-        return self.takeHightest(chunks,comparator)
+
+        return self.takeHightest(chunks, comparator)

-    def svmlFilter(self,chunks):
-        def comparator(a,b):
+    def svmlFilter(self, chunks):
+        def comparator(a, b):
             return b.standardDeviation() - a.standardDeviation()

         return self.takeHightest(chunks, comparator)

-    def logFreqFilter(self,chunks):
-        def comparator(a,b):
+    def logFreqFilter(self, chunks):
+        def comparator(a, b):
             return a.wordFrequency() - b.wordFrequency()

         return self.takeHightest(chunks, comparator)
@@ -83,6 +88,7 @@ class ComplexCompare:
 dictWord = {}
 maxWordLength = 0

+
 def loadDictChars(filepath):
     global maxWordLength
     fsock = open(filepath)
@@ -90,18 +96,22 @@ def loadDictChars(filepath):
         freq, word = line.split()
         word = word.strip()
         dictWord[word] = (len(word), int(freq))
-        maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
+        maxWordLength = len(word) if maxWordLength < len(
+            word) else maxWordLength
     fsock.close()

+
 def loadDictWords(filepath):
     global maxWordLength
     fsock = open(filepath)
     for line in fsock.readlines():
         word = line.strip()
         dictWord[word] = (len(word), 0)
-        maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
+        maxWordLength = len(word) if maxWordLength < len(
+            word) else maxWordLength
     fsock.close()

+
 #判断该词word是否在字典dictWord中
 def getDictWord(word):
     result = dictWord.get(word)
@@ -109,14 +119,15 @@ def getDictWord(word):
         return Word(word, result[1])
     return None

+
 #开始加载字典
 def run():
     from os.path import join, dirname
     loadDictChars(join(dirname(__file__), 'data', 'chars.dic'))
     loadDictWords(join(dirname(__file__), 'data', 'words.dic'))

-class Analysis:

+class Analysis:
     def __init__(self, text):
         self.text = text
         self.cacheSize = 3
@@ -134,11 +145,10 @@ class Analysis:
         if not dictWord:
             run()

-
     def __iter__(self):
         while True:
             token = self.getNextToken()
-            if token == None:
+            if token is None:
                 raise StopIteration
             yield token
@@ -146,7 +156,7 @@ class Analysis:
         return self.text[self.pos]

     #判断该字符是否是中文字符(不包括中文标点)
-    def isChineseChar(self,charater):
+    def isChineseChar(self, charater):
         return 0x4e00 <= ord(charater) < 0x9fa6

     #判断是否是ASCII码
@@ -163,8 +173,8 @@ class Analysis:
         while self.pos < self.textLength:
             if self.isChineseChar(self.getNextChar()):
                 token = self.getChineseWords()
-            else :
-                token = self.getASCIIWords()+'/'
+            else:
+                token = self.getASCIIWords() + '/'
             if len(token) > 0:
                 return token
         return None
@@ -211,7 +221,7 @@ class Analysis:
             chunks = self.complexCompare.svmlFilter(chunks)
         if len(chunks) > 1:
             chunks = self.complexCompare.logFreqFilter(chunks)
-        if len(chunks) == 0 :
+        if len(chunks) == 0:
             return ''

         #最后只有一种切割方法
@@ -242,13 +252,13 @@ class Analysis:
                     for word3 in words3:
                         # print(word3.length, word3.text)
                         if word3.length == -1:
-                            chunk = Chunk(word1,word2)
+                            chunk = Chunk(word1, word2)
                             # print("Ture")
-                        else :
-                            chunk = Chunk(word1,word2,word3)
+                        else:
+                            chunk = Chunk(word1, word2, word3)
                         chunks.append(chunk)
                 elif self.pos == self.textLength:
-                    chunks.append(Chunk(word1,word2))
+                    chunks.append(Chunk(word1, word2))
                     self.pos -= len(word2.text)
             elif self.pos == self.textLength:
                 chunks.append(Chunk(word1))
@@ -268,7 +278,7 @@ class Analysis:
         words = []
         index = 0
         while self.pos < self.textLength:
-            if index >= maxWordLength :
+            if index >= maxWordLength:
                 break
             if not self.isChineseChar(self.getNextChar()):
                 break
@@ -288,18 +298,18 @@ class Analysis:
                 word.text = 'X'
                 words.append(word)

-        self.cache[self.cacheIndex] = (self.pos,words)
+        self.cache[self.cacheIndex] = (self.pos, words)
         self.cacheIndex += 1
         if self.cacheIndex >= self.cacheSize:
             self.cacheIndex = 0
         return words

-if __name__=="__main__":

+if __name__ == "__main__":
+
     def cuttest(text):
         #cut = Analysis(text)
-        tmp=""
+        tmp = ""
         try:
             for word in iter(Analysis(text)):
                 tmp += word
@@ -375,6 +385,8 @@ if __name__=="__main__":
     cuttest(u"好人使用了它就可以解决一些问题")
     cuttest(u"是因为和国家")
     cuttest(u"老年搜索还支持")
-    cuttest(u"干脆就把那部蒙人的闲法给废了拉倒RT @laoshipukong : 27日全国人大常委会第三次审议侵权责任法草案删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 ")
+    cuttest(
+        u"干脆就把那部蒙人的闲法给废了拉倒RT @laoshipukong : 27日全国人大常委会第三次审议侵权责任法草案删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 "
+    )
     cuttest("2022年12月30日是星期几")
     cuttest("二零二二年十二月三十日是星期几?")

@@ -26,9 +26,9 @@ import argparse
 import os
 import re
 import subprocess
-from distutils.util import strtobool

 import numpy as np
+from distutils.util import strtobool

 FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
 SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")

@@ -1,17 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # CopyRight WeNet Apache-2.0 License
-import re, sys, unicodedata
 import codecs
+import re
+import sys
+import unicodedata

 remove_tag = True
-spacelist= [' ', '\t', '\r', '\n']
-puncts = ['!', ',', '?',
-          '', '', '', '', '', '',
-          '', '', '', '', '', '', '', '']
+spacelist = [' ', '\t', '\r', '\n']
+puncts = [
+    '!', ',', '?', '', '', '', '', '', '', '', '', '', '', '', '', '', ''
+]

-def characterize(string) :
+
+def characterize(string):
     res = []
     i = 0
     while i < len(string):
@@ -31,10 +34,10 @@ def characterize(string) :
         # some input looks like: <unk><noise>, we want to separate it to two words.
         sep = ' '
         if char == '<': sep = '>'
-        j = i+1
+        j = i + 1
         while j < len(string):
             c = string[j]
-            if ord(c) >= 128 or (c in spacelist) or (c==sep):
+            if ord(c) >= 128 or (c in spacelist) or (c == sep):
                 break
             j += 1
         if j < len(string) and string[j] == '>':
@@ -43,10 +46,12 @@ def characterize(string) :
             i = j
     return res

+
 def stripoff_tags(x):
     if not x: return ''
     chars = []
-    i = 0; T=len(x)
+    i = 0
+    T = len(x)
     while i < T:
         if x[i] == '<':
             while i < T and x[i] != '>':
@@ -78,8 +83,9 @@ def normalize(sentence, ignore_words, cs, split=None):
             new_sentence.append(x)
     return new_sentence

-class Calculator :
-    def __init__(self) :
+
+class Calculator:
+    def __init__(self):
         self.data = {}
         self.space = []
         self.cost = {}
@@ -87,66 +93,87 @@ class Calculator :
         self.cost['sub'] = 1
         self.cost['del'] = 1
         self.cost['ins'] = 1
-    def calculate(self, lab, rec) :
+
+    def calculate(self, lab, rec):
         # Initialization
         lab.insert(0, '')
         rec.insert(0, '')
-        while len(self.space) < len(lab) :
+        while len(self.space) < len(lab):
             self.space.append([])
-        for row in self.space :
-            for element in row :
+        for row in self.space:
+            for element in row:
                 element['dist'] = 0
                 element['error'] = 'non'
-            while len(row) < len(rec) :
-                row.append({'dist' : 0, 'error' : 'non'})
-        for i in range(len(lab)) :
+            while len(row) < len(rec):
+                row.append({'dist': 0, 'error': 'non'})
+        for i in range(len(lab)):
             self.space[i][0]['dist'] = i
             self.space[i][0]['error'] = 'del'
-        for j in range(len(rec)) :
+        for j in range(len(rec)):
             self.space[0][j]['dist'] = j
             self.space[0][j]['error'] = 'ins'
         self.space[0][0]['error'] = 'non'
-        for token in lab :
-            if token not in self.data and len(token) > 0 :
-                self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
-        for token in rec :
-            if token not in self.data and len(token) > 0 :
-                self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
+        for token in lab:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {
+                    'all': 0,
+                    'cor': 0,
+                    'sub': 0,
+                    'ins': 0,
+                    'del': 0
+                }
+        for token in rec:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {
+                    'all': 0,
+                    'cor': 0,
+                    'sub': 0,
+                    'ins': 0,
+                    'del': 0
+                }
         # Computing edit distance
-        for i, lab_token in enumerate(lab) :
-            for j, rec_token in enumerate(rec) :
-                if i == 0 or j == 0 :
+        for i, lab_token in enumerate(lab):
+            for j, rec_token in enumerate(rec):
+                if i == 0 or j == 0:
                     continue
                 min_dist = sys.maxsize
                 min_error = 'none'
-                dist = self.space[i-1][j]['dist'] + self.cost['del']
+                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                 error = 'del'
-                if dist < min_dist :
+                if dist < min_dist:
                     min_dist = dist
                     min_error = error
-                dist = self.space[i][j-1]['dist'] + self.cost['ins']
+                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                 error = 'ins'
-                if dist < min_dist :
+                if dist < min_dist:
                     min_dist = dist
                     min_error = error
-                if lab_token == rec_token :
-                    dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
+                if lab_token == rec_token:
+                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                     error = 'cor'
-                else :
-                    dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
+                else:
+                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                     error = 'sub'
-                if dist < min_dist :
+                if dist < min_dist:
                     min_dist = dist
                     min_error = error
                 self.space[i][j]['dist'] = min_dist
                 self.space[i][j]['error'] = min_error
         # Tracing back
-        result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+        result = {
+            'lab': [],
+            'rec': [],
+            'all': 0,
+            'cor': 0,
+            'sub': 0,
+            'ins': 0,
+            'del': 0
+        }
         i = len(lab) - 1
         j = len(rec) - 1
-        while True :
-            if self.space[i][j]['error'] == 'cor' : # correct
-                if len(lab[i]) > 0 :
+        while True:
+            if self.space[i][j]['error'] == 'cor':  # correct
+                if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                 result['all'] = result['all'] + 1
@@ -155,8 +182,8 @@ class Calculator :
                 result['rec'].insert(0, rec[j])
                 i = i - 1
                 j = j - 1
-            elif self.space[i][j]['error'] == 'sub' : # substitution
-                if len(lab[i]) > 0 :
+            elif self.space[i][j]['error'] == 'sub':  # substitution
+                if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                 result['all'] = result['all'] + 1
@@ -165,8 +192,8 @@ class Calculator :
                 result['rec'].insert(0, rec[j])
                 i = i - 1
                 j = j - 1
-            elif self.space[i][j]['error'] == 'del' : # deletion
-                if len(lab[i]) > 0 :
+            elif self.space[i][j]['error'] == 'del':  # deletion
+                if len(lab[i]) > 0:
                     self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                     self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                 result['all'] = result['all'] + 1
@@ -174,57 +201,64 @@ class Calculator :
                 result['lab'].insert(0, lab[i])
                 result['rec'].insert(0, "")
                 i = i - 1
-            elif self.space[i][j]['error'] == 'ins' : # insertion
-                if len(rec[j]) > 0 :
+            elif self.space[i][j]['error'] == 'ins':  # insertion
+                if len(rec[j]) > 0:
                     self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                 result['ins'] = result['ins'] + 1
                 result['lab'].insert(0, "")
                 result['rec'].insert(0, rec[j])
                 j = j - 1
-            elif self.space[i][j]['error'] == 'non' : # starting point
+            elif self.space[i][j]['error'] == 'non':  # starting point
                 break
-            else : # shouldn't reach here
-                print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
+            else:  # shouldn't reach here
+                print(
+                    'this should not happen , i = {i} , j = {j} , error = {error}'.
+                    format(i=i, j=j, error=self.space[i][j]['error']))
         return result
-    def overall(self) :
-        result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
-        for token in self.data :
+
+    def overall(self):
+        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
         return result
-    def cluster(self, data) :
-        result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
-        for token in data :
-            if token in self.data :
+
+    def cluster(self, data):
+        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+        for token in data:
+            if token in self.data:
                 result['all'] = result['all'] + self.data[token]['all']
                 result['cor'] = result['cor'] + self.data[token]['cor']
                 result['sub'] = result['sub'] + self.data[token]['sub']
                 result['ins'] = result['ins'] + self.data[token]['ins']
                 result['del'] = result['del'] + self.data[token]['del']
         return result
-    def keys(self) :
+
+    def keys(self):
         return list(self.data.keys())

+
 def width(string):
     return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)

-def default_cluster(word) :
-    unicode_names = [ unicodedata.name(char) for char in word ]
-    for i in reversed(range(len(unicode_names))) :
-        if unicode_names[i].startswith('DIGIT') : # 1
+
+def default_cluster(word):
+    unicode_names = [unicodedata.name(char) for char in word]
+    for i in reversed(range(len(unicode_names))):
+        if unicode_names[i].startswith('DIGIT'):  # 1
             unicode_names[i] = 'Number'  # 'DIGIT'
         elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
-              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
+              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
             # 明 / 郎
             unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
         elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
-              unicode_names[i].startswith('LATIN SMALL LETTER')) :
+              unicode_names[i].startswith('LATIN SMALL LETTER')):
             # A / a
             unicode_names[i] = 'English'  # 'LATIN LETTER'
-        elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
+        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
             unicode_names[i] = 'Japanese'  # 'GANA LETTER'
         elif (unicode_names[i].startswith('AMPERSAND') or
               unicode_names[i].startswith('APOSTROPHE') or
@@ -236,34 +270,40 @@ def default_cluster(word) :
               unicode_names[i].startswith('LOW LINE') or
               unicode_names[i].startswith('NUMBER SIGN') or
               unicode_names[i].startswith('PLUS SIGN') or
-              unicode_names[i].startswith('SEMICOLON')) :
+              unicode_names[i].startswith('SEMICOLON')):
             # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
             del unicode_names[i]
-        else :
+        else:
             return 'Other'
-    if len(unicode_names) == 0 :
+    if len(unicode_names) == 0:
         return 'Other'
-    if len(unicode_names) == 1 :
+    if len(unicode_names) == 1:
         return unicode_names[0]
-    for i in range(len(unicode_names)-1) :
-        if unicode_names[i] != unicode_names[i+1] :
+    for i in range(len(unicode_names) - 1):
+        if unicode_names[i] != unicode_names[i + 1]:
             return 'Other'
     return unicode_names[0]

-def usage() :
-    print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
-    print("         usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
+
+def usage():
+    print(
+        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
+    )
+    print(
+        "         usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
+    )

+
 if __name__ == '__main__':
-    if len(sys.argv) == 1 :
+    if len(sys.argv) == 1:
         usage()
         sys.exit(0)
     calculator = Calculator()
     cluster_file = ''
     ignore_words = set()
     tochar = False
-    verbose= 1
-    padding_symbol= ' '
+    verbose = 1
+    padding_symbol = ' '
     case_sensitive = False
     max_words_per_line = sys.maxsize
     split = None
@@ -322,9 +362,9 @@ if __name__ == '__main__':
         if sys.argv[1].startswith(a):
             b = sys.argv[1][len(a):].lower()
             del sys.argv[1]
-            verbose=0
+            verbose = 0
             try:
-                verbose=int(b)
+                verbose = int(b)
             except:
                 if b == 'true' or b != '0':
                     verbose = 1
@@ -334,9 +374,9 @@ if __name__ == '__main__':
             b = sys.argv[1][len(a):].lower()
             del sys.argv[1]
             if b == 'space':
-                padding_symbol= ' '
+                padding_symbol = ' '
             elif b == 'underline':
-                padding_symbol= '_'
+                padding_symbol = '_'
             continue
         if True or sys.argv[1].startswith('-'):
             #ignore invalid switch
@@ -344,7 +384,7 @@ if __name__ == '__main__':
             continue

     if not case_sensitive:
-        ig=set([w.upper() for w in ignore_words])
+        ig = set([w.upper() for w in ignore_words])
         ignore_words = ig

     default_clusters = {}
@@ -368,17 +408,18 @@ if __name__ == '__main__':
             array = characterize(line)
         else:
             array = line.strip().split()
-        if len(array)==0: continue
+        if len(array) == 0: continue
         fid = array[0]
-        rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+        rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
+                                 split)

     # compute error rate on the interaction of reference file and hyp file
-    for line in open(ref_file, 'r', encoding='utf-8') :
+    for line in open(ref_file, 'r', encoding='utf-8'):
         if tochar:
             array = characterize(line)
         else:
             array = line.rstrip('\n').split()
-        if len(array)==0: continue
+        if len(array) == 0: continue
         fid = array[0]
         if fid not in rec_set:
             continue
@@ -387,105 +428,116 @@ if __name__ == '__main__':
         if verbose:
             print('\nutt: %s' % fid)
-        for word in rec + lab :
-            if word not in default_words :
+        for word in rec + lab:
+            if word not in default_words:
                 default_cluster_name = default_cluster(word)
-                if default_cluster_name not in default_clusters :
+                if default_cluster_name not in default_clusters:
                     default_clusters[default_cluster_name] = {}
-                if word not in default_clusters[default_cluster_name] :
+                if word not in default_clusters[default_cluster_name]:
                     default_clusters[default_cluster_name][word] = 1
                 default_words[word] = default_cluster_name

         result = calculator.calculate(lab, rec)
         if verbose:
-            if result['all'] != 0 :
-                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-            else :
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result[
+                    'del']) * 100.0 / result['all']
+            else:
                 wer = 0.0
-            print('WER: %4.2f %%' % wer, end = ' ')
+            print('WER: %4.2f %%' % wer, end=' ')
             print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+                  (result['all'], result['cor'], result['sub'], result['del'],
+                   result['ins']))
             space = {}
             space['lab'] = []
             space['rec'] = []
-            for idx in range(len(result['lab'])) :
+            for idx in range(len(result['lab'])):
                 len_lab = width(result['lab'][idx])
                 len_rec = width(result['rec'][idx])
                 length = max(len_lab, len_rec)
-                space['lab'].append(length-len_lab)
-                space['rec'].append(length-len_rec)
+                space['lab'].append(length - len_lab)
+                space['rec'].append(length - len_rec)
             upper_lab = len(result['lab'])
             upper_rec = len(result['rec'])
             lab1, rec1 = 0, 0
             while lab1 < upper_lab or rec1 < upper_rec:
                 if verbose > 1:
-                    print('lab(%s):' % fid.encode('utf-8'), end = ' ')
+                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                 else:
-                    print('lab:', end = ' ')
+                    print('lab:', end=' ')
                 lab2 = min(upper_lab, lab1 + max_words_per_line)
                 for idx in range(lab1, lab2):
                     token = result['lab'][idx]
-                    print('{token}'.format(token = token), end = '')
-                    for n in range(space['lab'][idx]) :
-                        print(padding_symbol, end = '')
-                    print(' ',end='')
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['lab'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
                 print()
                 if verbose > 1:
-                    print('rec(%s):' % fid.encode('utf-8'), end = ' ')
+                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                 else:
-                    print('rec:', end = ' ')
+                    print('rec:', end=' ')
                 rec2 = min(upper_rec, rec1 + max_words_per_line)
                 for idx in range(rec1, rec2):
                     token = result['rec'][idx]
-                    print('{token}'.format(token = token), end = '')
-                    for n in range(space['rec'][idx]) :
-                        print(padding_symbol, end = '')
-                    print(' ',end='')
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['rec'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
                 print('\n', end='\n')
                 lab1 = lab2
                 rec1 = rec2

     if verbose:
-        print('===========================================================================')
+        print(
+            '==========================================================================='
+        )
         print()

     result = calculator.overall()
-    if result['all'] != 0 :
-        wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-    else :
+    if result['all'] != 0:
+        wer = float(result['ins'] + result['sub'] + result[
+            'del']) * 100.0 / result['all']
+    else:
         wer = 0.0
-    print('Overall -> %4.2f %%' % wer, end = ' ')
+    print('Overall -> %4.2f %%' % wer, end=' ')
     print('N=%d C=%d S=%d D=%d I=%d' %
-          (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+          (result['all'], result['cor'], result['sub'], result['del'],
+           result['ins']))
     if not verbose:
         print()

     if verbose:
-        for cluster_id in default_clusters :
-            result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
-            if result['all'] != 0 :
-                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-            else :
+        for cluster_id in default_clusters:
+            result = calculator.cluster(
+                [k for k in default_clusters[cluster_id]])
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result[
+                    'del']) * 100.0 / result['all']
+            else:
                 wer = 0.0
-            print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
             print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-        if len(cluster_file) > 0 : # compute separated WERs for word clusters
+                  (result['all'], result['cor'], result['sub'], result['del'],
+                   result['ins']))
+        if len(cluster_file) > 0:  # compute separated WERs for word clusters
            cluster_id = ''
            cluster = []
-            for line in open(cluster_file, 'r', encoding='utf-8') :
-                for token in line.decode('utf-8').rstrip('\n').split() :
+            for line in open(cluster_file, 'r', encoding='utf-8'):
+                for token in line.decode('utf-8').rstrip('\n').split():
                     # end of cluster reached, like </Keyword>
                     if token[0:2] == '</' and token[len(token)-1] == '>' and \
                        token.lstrip('</').rstrip('>') == cluster_id :
                         result = calculator.cluster(cluster)
-                        if result['all'] != 0 :
-                            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-                        else :
+                        if result['all'] != 0:
+                            wer = float(result['ins'] + result['sub'] + result[
+                                'del']) * 100.0 / result['all']
+                        else:
                             wer = 0.0
-                        print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
                         print('N=%d C=%d S=%d D=%d I=%d' %
-                              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+                              (result['all'], result['cor'], result['sub'],
+                               result['del'], result['ins']))
                         cluster_id = ''
                         cluster = []
                     # begin of cluster reached, like <Keyword>
@@ -494,7 +546,9 @@ if __name__ == '__main__':
                         cluster_id = token.lstrip('<').rstrip('>')
                         cluster = []
                     # general terms, like WEATHER / CAR / ...
-                    else :
+                    else:
                         cluster.append(token)
     print()
-    print('===========================================================================')
+    print(
+        '==========================================================================='
+    )
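
Note: the WER printed throughout this script is the standard (S + D + I) / N, with N the reference token count; a tiny worked example, independent of the diff:

    # ref: "the cat sat", hyp: "the cat sat down" -> S=0, D=0, I=1, N=3
    S, D, I, N = 0, 0, 1, 3
    wer = float(S + D + I) * 100.0 / N
    print('WER: %4.2f %%' % wer)  # WER: 33.33 %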

@@ -1,11 +1,21 @@
-import os
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse

 import jsonlines

-def trans_hyp(origin_hyp,
-              trans_hyp = None,
-              trans_hyp_sclite = None):
+
+def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
     """
     Args:
         origin_hyp: The input json file which contains the model output
@@ -24,12 +34,11 @@ def trans_hyp(origin_hyp,
     if trans_hyp_sclite is not None:
         with open(trans_hyp_sclite, "w+") as f:
             for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                 f.write(line)

-def trans_ref(origin_ref,
-              trans_ref = None,
-              trans_ref_sclite = None):
+
+def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
     """
     Args:
         origin_hyp: The input json file which contains the model output
@@ -49,42 +58,48 @@ def trans_ref(origin_ref,
     if trans_ref_sclite is not None:
         with open(trans_ref_sclite, "w") as f:
             for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                 f.write(line)

+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
+    parser = argparse.ArgumentParser(
+        prog='format hyp file for compute CER/WER', add_help=True)
     parser.add_argument(
-        '--origin_hyp',
-        type=str,
-        default = None,
-        help='origin hyp file')
+        '--origin_hyp', type=str, default=None, help='origin hyp file')
     parser.add_argument(
-        '--trans_hyp', type=str, default = None, help='hyp file for caculating CER/WER')
+        '--trans_hyp',
+        type=str,
+        default=None,
+        help='hyp file for caculating CER/WER')
     parser.add_argument(
-        '--trans_hyp_sclite', type=str, default = None, help='hyp file for caculating CER/WER by sclite')
+        '--trans_hyp_sclite',
+        type=str,
+        default=None,
+        help='hyp file for caculating CER/WER by sclite')
     parser.add_argument(
-        '--origin_ref',
-        type=str,
-        default = None,
-        help='origin ref file')
+        '--origin_ref', type=str, default=None, help='origin ref file')
     parser.add_argument(
-        '--trans_ref', type=str, default = None, help='ref file for caculating CER/WER')
+        '--trans_ref',
+        type=str,
+        default=None,
+        help='ref file for caculating CER/WER')
     parser.add_argument(
-        '--trans_ref_sclite', type=str, default = None, help='ref file for caculating CER/WER by sclite')
+        '--trans_ref_sclite',
+        type=str,
+        default=None,
+        help='ref file for caculating CER/WER by sclite')
     parser_args = parser.parse_args()

     if parser_args.origin_hyp is not None:
         trans_hyp(
-            origin_hyp = parser_args.origin_hyp,
-            trans_hyp = parser_args.trans_hyp,
-            trans_hyp_sclite = parser_args.trans_hyp_sclite, )
+            origin_hyp=parser_args.origin_hyp,
+            trans_hyp=parser_args.trans_hyp,
+            trans_hyp_sclite=parser_args.trans_hyp_sclite, )

     if parser_args.origin_ref is not None:
         trans_ref(
-            origin_ref = parser_args.origin_ref,
-            trans_ref = parser_args.trans_ref,
-            trans_ref_sclite = parser_args.trans_ref_sclite, )
+            origin_ref=parser_args.origin_ref,
+            trans_ref=parser_args.trans_ref,
+            trans_ref_sclite=parser_args.trans_ref_sclite, )
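
Note: the lines written for sclite above follow its trn convention, "<text>(<id>.wav)"; a hedged example with invented values:

    key, text = "utt1", "今天 天气 好"  # illustrative utterance id and hypothesis
    line = text + "(" + key + ".wav" + ")" + "\n"
    assert line == "今天 天气 好(utt1.wav)\n"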

@@ -82,7 +82,10 @@ def main(args):
             lexicon_table.add(word)
         out_n += 1

-    print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")
+    print(
+        f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}"
+    )

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
