parent c7c95025fc
commit 6d8a2de6d8
@@ -0,0 +1,42 @@
import codecs


class Alphabet(object):
    """Mapping between text entries and the integer labels used by the network."""

    def __init__(self, config_file):
        self._config_file = config_file
        self._label_to_str = []
        self._str_to_label = {}
        self._size = 0
        self.blank_token = 1
        with codecs.open(config_file, 'r', 'utf-8') as fin:
            for line in fin:
                if line[0:2] == '\\#':
                    # '\#' escapes a literal '#' entry
                    line = '#\n'
                elif line[0] == '#':
                    # '#' starts a comment line
                    continue
                self._label_to_str.append(line[:-1])  # remove the line ending
                self._str_to_label[line[:-1]] = self._size
                self._size += 1

    def string_from_label(self, label):
        return self._label_to_str[label]

    def label_from_string(self, string):
        try:
            return self._str_to_label[string]
        except KeyError as e:
            raise KeyError(
                'ERROR: Your transcripts contain characters which do not occur in '
                'data/alphabet.txt! Use util/check_characters.py to see what '
                'characters are in your {train,dev,test}.csv transcripts, and then '
                'add all these to data/alphabet.txt.'
            ).with_traceback(e.__traceback__)

    def decode(self, labels):
        res = ''
        for label in labels:
            res += self.string_from_label(label)
        return res

    def size(self):
        return self._size

    def config_file(self):
        return self._config_file
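A minimal usage sketch (the file name and its contents are assumptions for illustration; alphabet files list one entry per line, with '#' starting a comment and '\#' escaping a literal '#'; the codecs import comes from the module above):

# Hypothetical three-entry alphabet: labels 0, 1, 2 map to 'h', ' ', 'i'.
with codecs.open('tiny_alphabet.txt', 'w', 'utf-8') as f:
    f.write('h\n \ni\n')

alphabet = Alphabet('tiny_alphabet.txt')
print(alphabet.size())                   # 3
print(alphabet.label_from_string('i'))   # 2
print(alphabet.string_from_label(0))     # h
print(alphabet.decode([0, 2]))           # hi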
@@ -0,0 +1,53 @@
import json


class Words(object):
    """Converts decoder output into per-word timing and confidence information."""

    def __init__(self, prob_split, metadata, alphabet, frame_to_sec=0.03):
        self.raw_output = ''
        self.extended_output = []

        word = ''
        start_step, confidence, num_char = 0, 1.0, 0
        metadata_size = metadata.tokens.size()

        for i in range(metadata_size):
            token = metadata.tokens[i]
            letter = alphabet.string_from_label(token)
            time_step = metadata.timesteps[i]

            # prepare raw output
            self.raw_output += letter

            # prepare extended output
            if token != alphabet.blank_token:
                word += letter
                confidence *= prob_split[time_step][token]
                num_char += 1
                if len(word) == 1:
                    start_step = time_step

            if token == alphabet.blank_token or i == metadata_size - 1:
                if word:  # consecutive blanks would otherwise emit empty words
                    duration_step = time_step - start_step

                    if duration_step < 0:
                        duration_step = 0

                    self.extended_output.append({"word": word,
                                                 "start_time": frame_to_sec * start_step,
                                                 "duration": frame_to_sec * duration_step,
                                                 # geometric mean of the per-character probabilities
                                                 "confidence": confidence**(1.0 / num_char)})
                # reset
                word = ''
                start_step, confidence, num_char = 0, 1.0, 0

    def to_json(self):
        return json.dumps({"raw_output": self.raw_output,
                           "extended_output": self.extended_output})

    def save_json(self, file_path):
        with open(file_path, 'w') as outfile:
            json.dump({"raw_output": self.raw_output,
                       "extended_output": self.extended_output},
                      outfile)
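A sketch of how Words might be driven. The metadata object below is a hand-rolled stand-in (an assumption, not the real decoder output) mimicking the interface the constructor expects: a tokens container supporting size() and indexing, plus a parallel timesteps sequence. It reuses the tiny_alphabet.txt Alphabet from the earlier sketch, whose label 1 happens to coincide with blank_token:

class FakeTokens(list):
    # stand-in token container exposing the size() method Words calls
    def size(self):
        return len(self)


class FakeMetadata(object):
    def __init__(self, tokens, timesteps):
        self.tokens = FakeTokens(tokens)
        self.timesteps = timesteps


# 'h', 'i', blank, 'h', 'i' at frames 0..4, with made-up frame probabilities
metadata = FakeMetadata(tokens=[0, 2, 1, 0, 2], timesteps=[0, 1, 2, 3, 4])
prob_split = [[0.9, 0.05, 0.05]] * 5  # prob_split[time_step][token]

words = Words(prob_split, metadata, alphabet)
print(words.to_json())
# {"raw_output": "hi hi", "extended_output": [{"word": "hi", ...}, ...]}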