fix text featurizer

pull/578/head
Hui Zhang 5 years ago
parent 64f0bad5ca
commit af453e0234

@@ -295,6 +295,7 @@
 "source": [
 "dataset = ManifestDataset(\n",
 "    config.data.test_manifest,\n",
+"    config.data.unit_type,\n",
 "    config.data.vocab_filepath,\n",
 "    config.data.mean_std_filepath,\n",
 "    augmentation_config=\"{}\",\n",

@@ -85,8 +85,10 @@ def start_server(config, args):
     """Start the ASR server"""
     dataset = ManifestDataset(
         config.data.test_manifest,
+        config.data.unit_type,
         config.data.vocab_filepath,
         config.data.mean_std_filepath,
+        spm_model_prefix=config.data.spm_model_prefix,
         augmentation_config="{}",
         max_duration=config.data.max_duration,
         min_duration=config.data.min_duration,

@@ -37,8 +37,10 @@ def start_server(config, args):
     """Start the ASR server"""
     dataset = ManifestDataset(
         config.data.test_manifest,
+        config.data.unit_type,
         config.data.vocab_filepath,
         config.data.mean_std_filepath,
+        spm_model_prefix=config.data.spm_model_prefix,
         augmentation_config="{}",
         max_duration=config.data.max_duration,
         min_duration=config.data.min_duration,

@@ -43,8 +43,10 @@ def tune(config, args):
     dev_dataset = ManifestDataset(
         config.data.dev_manifest,
+        config.data.unit_type,
         config.data.vocab_filepath,
         config.data.mean_std_filepath,
+        spm_model_prefix=config.data.spm_model_prefix,
         augmentation_config="{}",
         max_duration=config.data.max_duration,
         min_duration=config.data.min_duration,

@@ -21,7 +21,9 @@ _C.data = CN(
         train_manifest="",
         dev_manifest="",
         test_manifest="",
+        unit_type="char",
         vocab_filepath="",
+        spm_model_prefix="",
         mean_std_filepath="",
         augmentation_config="",
         max_duration=float('inf'),

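The two new keys travel together: `spm_model_prefix` is only consulted when `unit_type` is 'spm'. A minimal yacs sketch of overriding the defaults, with hypothetical paths:

    from yacs.config import CfgNode

    config = CfgNode(dict(data=CfgNode(dict(unit_type="char", spm_model_prefix=""))))
    config.data.unit_type = "spm"                      # 'char' (default), 'word', or 'spm'
    config.data.spm_model_prefix = "data/bpe_unigram"  # featurizer loads <prefix>.model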
@@ -148,8 +148,10 @@ class DeepSpeech2Trainer(Trainer):
         train_dataset = ManifestDataset(
             config.data.train_manifest,
+            config.data.unit_type,
             config.data.vocab_filepath,
             config.data.mean_std_filepath,
+            spm_model_prefix=config.data.spm_model_prefix,
             augmentation_config=io.open(
                 config.data.augmentation_config, mode='r',
                 encoding='utf8').read(),
@@ -168,8 +170,10 @@ class DeepSpeech2Trainer(Trainer):
         dev_dataset = ManifestDataset(
             config.data.dev_manifest,
+            config.data.unit_type,
             config.data.vocab_filepath,
             config.data.mean_std_filepath,
+            spm_model_prefix=config.data.spm_model_prefix,
             augmentation_config="{}",
             max_duration=config.data.max_duration,
             min_duration=config.data.min_duration,
@@ -361,8 +365,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         # return raw text
         test_dataset = ManifestDataset(
             config.data.test_manifest,
+            config.data.unit_type,
             config.data.vocab_filepath,
             config.data.mean_std_filepath,
+            spm_model_prefix=config.data.spm_model_prefix,
             augmentation_config="{}",
             max_duration=config.data.max_duration,
             min_duration=config.data.min_duration,

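Since `unit_type` is the second positional argument of `ManifestDataset`, every call site above changes in lockstep. A sketch of the resulting call shape, with hypothetical paths:

    dataset = ManifestDataset(
        "data/manifest.test",   # manifest_path
        "char",                 # unit_type: 'char', 'word', or 'spm'
        "data/vocab.txt",       # vocab_filepath
        "data/mean_std.npz",    # mean_std_filepath
        spm_model_prefix="",    # only used when unit_type == 'spm'
        augmentation_config="{}")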
@@ -109,7 +109,7 @@ class AudioFeaturizer(object):
             feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 +
                            1)
         elif self._specgram_type == 'mfcc':
-            # mfcc,delta, delta-delta
+            # mfcc, delta, delta-delta
             feat_dim = int(13 * 3)
         else:
             raise ValueError("Unknown specgram_type %s. "

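As a quick dimension check: the mfcc branch stacks 13 static coefficients with their deltas and delta-deltas, and the linear branch's arithmetic implies `fft_point` carries the window length in milliseconds (the 20 ms / 16 kHz numbers below are illustrative only):

    mfcc_dim = 13 * 3                              # 13 static + 13 delta + 13 delta-delta = 39
    linear_dim = int(20 * (16000 / 1000) / 2 + 1)  # 20 ms window at 16 kHz -> 161 bins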
@@ -52,7 +52,9 @@ class SpeechFeaturizer(object):
     """

     def __init__(self,
+                 unit_type,
                  vocab_filepath,
+                 spm_model_prefix=None,
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
@@ -70,7 +72,8 @@ class SpeechFeaturizer(object):
             target_sample_rate=target_sample_rate,
             use_dB_normalization=use_dB_normalization,
             target_dB=target_dB)
-        self._text_featurizer = TextFeaturizer(vocab_filepath)
+        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
+                                               spm_model_prefix)

     def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

@@ -15,25 +15,35 @@
 import os
 import codecs
+import sentencepiece as spm
+
+from deepspeech.frontend.utility import UNK


 class TextFeaturizer(object):
-    """Text featurizer, for processing or extracting features from text.
-
-    Currently, it only supports char-level tokenizing and conversion into
-    a list of token indices. Note that the token indexing order follows the
-    given vocabulary file.
-
-    :param vocab_filepath: Filepath to load vocabulary for token indices
-                           conversion.
-    :type specgram_type: str
-    """
-
-    def __init__(self, vocab_filepath):
-        self.unk = '<unk>'
+    def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
+        """Text featurizer, for processing or extracting features from text.
+
+        Currently, it supports char/word/sentence-piece level tokenizing and
+        conversion into a list of token indices. Note that the token indexing
+        order follows the given vocabulary file.
+
+        Args:
+            unit_type (str): unit type, e.g. char, word, spm
+            vocab_filepath (str): Filepath to load vocabulary for token indices conversion.
+            spm_model_prefix (str, optional): spm model prefix. Defaults to None.
+        """
+        assert unit_type in ('char', 'spm', 'word')
+        self.unk = UNK
+        self.unit_type = unit_type
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
+        if unit_type == 'spm':
+            self.spm_model = spm_model_prefix + '.model'
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.Load(self.spm_model)

     def featurize(self, text):
         """Convert text string to a list of token indices in char-level. Note
         that the token indexing order follows the given vocabulary file.
@@ -43,7 +53,13 @@ class TextFeaturizer(object):
         :return: List of char-level token indices.
         :rtype: list
         """
-        tokens = self._char_tokenize(text)
+        if self.unit_type == 'char':
+            tokens = self._char_tokenize(text)
+        elif self.unit_type == 'word':
+            tokens = self._word_tokenize(text)
+        else:  # spm
+            tokens = self._spm_tokenize(text)
         ids = []
         for token in tokens:
             token = token if token in self._vocab_dict else self.unk
@@ -72,6 +88,42 @@ class TextFeaturizer(object):
         """Character tokenizer."""
         return list(text.strip())

+    def _word_tokenize(self, text):
+        """Word tokenizer, separated by <space>."""
+        return text.strip().split()
+
+    def _spm_tokenize(self, text):
+        """spm tokenize.
+
+        Args:
+            text (str): text string.
+
+        Returns:
+            List[str]: sentence pieces str code
+        """
+        stats = {"num_empty": 0, "num_filtered": 0}
+
+        def valid(line):
+            return True
+
+        def encode(l):
+            return self.sp.EncodeAsPieces(l)
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        enc_line = encode_line(text)
+        return enc_line
+
     def _load_vocabulary_from_file(self, vocab_filepath):
         """Load vocabulary from file."""
         vocab_lines = []

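A usage sketch of the three unit types; vocab and model paths are hypothetical, and the exact pieces produced by 'spm' depend on the trained model:

    char_feat = TextFeaturizer('char', 'data/vocab.txt')
    char_feat.featurize("hello")        # ids for ['h', 'e', 'l', 'l', 'o']

    word_feat = TextFeaturizer('word', 'data/vocab_word.txt')
    word_feat.featurize("hello world")  # ids for ['hello', 'world']

    # 'spm' loads <prefix>.model; the vocab file must match the spm model
    spm_feat = TextFeaturizer('spm', 'data/vocab_bpe.txt',
                              spm_model_prefix='data/bpe_unigram')
    spm_feat.featurize("hello world")   # ids for pieces such as ['▁hello', '▁world']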
@@ -16,6 +16,7 @@
 import numpy as np
 import random

 from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.utility import load_cmvn
 from deepspeech.frontend.audio import AudioSegment

@@ -79,10 +80,8 @@ class FeatureNormalizer(object):
     def _read_mean_std_from_file(self, filepath, eps=1e-20):
         """Load mean and std from file."""
-        npzfile = np.load(filepath)
-        self._mean = npzfile["mean"]
-        std = npzfile["std"]
-        std = np.clip(std, eps, None)
+        mean, std = load_cmvn(filepath, filetype='npz')
+        self._mean = mean
         self._istd = 1.0 / std

     def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
@@ -92,8 +91,7 @@ class FeatureNormalizer(object):
         features = []
         for instance in sampled_manifest:
             features.append(
-                featurize_func(
-                    AudioSegment.from_file(instance["audio_filepath"])))
+                featurize_func(AudioSegment.from_file(instance["feat"])))
         features = np.hstack(features)  #(D, T)
         self._mean = np.mean(features, axis=1).reshape([-1, 1])  #(D, 1)
         self._std = np.std(features, axis=1).reshape([-1, 1])  #(D, 1)

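For reference, the npz file consumed here stores per-dimension statistics; a sketch consistent with the `(D, 1)` shape comments added in this commit (161 dims is illustrative):

    import numpy as np

    D = 161  # feature dimension, e.g. linear spectrogram bins
    np.savez('mean_std.npz',
             mean=np.zeros([D, 1], dtype='float32'),
             std=np.ones([D, 1], dtype='float32'))

    npz = np.load('mean_std.npz')
    assert npz['mean'].shape == npz['std'].shape == (D, 1)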
@@ -20,6 +20,7 @@ import os
 import tarfile
 import time
 import logging
+from typing import List
 from threading import Thread
 from multiprocessing import Process, Manager, Value

@@ -39,31 +40,32 @@ EOS = SOS
 UNK = "<unk>"
 BLANK = "<blank>"

-# """Load and parse manifest file.
-# Instances with durations outside [min_duration, max_duration] will be
-# filtered out.
-# :param manifest_path: Manifest file to load and parse.
-# :type manifest_path: str
-# :param max_duration: maximum output seq length, in seconds for raw wav, in frame numbers for feature data.
-# :type max_duration: float
-# :param min_duration: minimum input seq length, in seconds for raw wav, in frame numbers for feature data.
-# :type min_duration: float
-# :return: Manifest parsing results. List of dict.
-# :rtype: list
-# :raises IOError: If failed to parse the manifest.
-# """
 def read_manifest(
         manifest_path,
         max_input_len=float('inf'),
         min_input_len=0.0,
-        max_output_len=500.0,
+        max_output_len=float('inf'),
         min_output_len=0.0,
-        max_output_input_ratio=10.0,
-        min_output_input_ratio=0.05, ):
+        max_output_input_ratio=float('inf'),
+        min_output_input_ratio=0.0, ):
+    """Load and parse manifest file.
+
+    Args:
+        manifest_path (str): Manifest file to load and parse.
+        max_input_len (float, optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
+        min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
+        max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to float('inf').
+        min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0.
+        max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to float('inf').
+        min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.0.
+
+    Raises:
+        IOError: If failed to parse the manifest.
+
+    Returns:
+        List[dict]: Manifest parsing results.
+    """
     manifest = []
     for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
@@ -71,33 +73,23 @@ def read_manifest(
             json_data = json.loads(json_line)
         except Exception as e:
             raise IOError("Error reading manifest: %s" % str(e))
-        feat_len = json_data["feat_shape"][0]
-        token_len = json_data["token_shape"][0]
+        feat_len = json_data["feat_shape"][
+            0] if 'feat_shape' in json_data else 1.0
+        token_len = json_data["token_shape"][
+            0] if 'token_shape' in json_data else 1.0
         conditions = [
-            feat_len > min_input_len,
-            feat_len < max_input_len,
-            token_len > min_output_len,
-            token_len < max_output_len,
-            token_len / feat_len > min_output_input_ratio,
-            token_len / feat_len < max_output_input_ratio,
+            feat_len >= min_input_len,
+            feat_len <= max_input_len,
+            token_len >= min_output_len,
+            token_len <= max_output_len,
+            token_len / feat_len >= min_output_input_ratio,
+            token_len / feat_len <= max_output_input_ratio,
         ]
         if all(conditions):
             manifest.append(json_data)
     return manifest

-# parser.add_argument('--max_input_len', type=float,
-#                     default=20,
-#                     help='maximum output seq length, in seconds for raw wav, in frame numbers for feature data')
-# parser.add_argument('--min_output_len', type=float,
-#                     default=0, help='minimum input seq length, in modeling units')
-# parser.add_argument('--max_output_len', type=float,
-#                     default=500,
-#                     help='maximum output seq length, in modeling units')
-# parser.add_argument('--min_output_input_ratio', type=float, default=0.05,
-#                     help='minimum output seq length/output seq length ratio')
-# parser.add_argument('--max_output_input_ratio', type=float, default=10,
-#                     help='maximum output seq length/output seq length ratio')

 def rms_to_db(rms: float):
     """Root Mean Square to dB.
@@ -251,8 +243,8 @@ def _load_kaldi_cmvn(kaldi_cmvn_file):
 def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
     npzfile = np.load(npz_cmvn_file)
-    means = npzfile["mean"]
-    std = npzfile["std"]
+    means = npzfile["mean"]  #(D, 1)
+    std = npzfile["std"]  #(D, 1)
     std = np.clip(std, eps, None)
     variance = 1.0 / std
     cmvn = np.array([means, variance])
@@ -278,7 +270,7 @@ def load_cmvn(cmvn_file: str, filetype: str):
         cmvn = _load_json_cmvn(cmvn_file)
     elif filetype == "kaldi":
         cmvn = _load_kaldi_cmvn(cmvn_file)
-    elif filtype == "npz":
+    elif filetype == "npz":
         cmvn = _load_npz_cmvn(cmvn_file)
     else:
         raise ValueError(f"cmvn file type no support: {filetype}")

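A standalone restatement of the new filter, to make the strict-to-inclusive change and the 1.0 fallback concrete (the helper is illustrative, not part of the codebase):

    def passes(feat_len, token_len,
               max_input_len=float('inf'), min_input_len=0.0,
               max_output_len=float('inf'), min_output_len=0.0,
               max_ratio=float('inf'), min_ratio=0.0):
        # mirrors the conditions in read_manifest; bounds are now inclusive
        return (min_input_len <= feat_len <= max_input_len and
                min_output_len <= token_len <= max_output_len and
                min_ratio <= token_len / feat_len <= max_ratio)

    assert passes(feat_len=100, token_len=100, min_input_len=100)  # old strict '>' rejected this
    assert passes(feat_len=1.0, token_len=1.0)  # shape-less manifest entries fall back to 1.0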
@@ -21,8 +21,10 @@ from deepspeech.io.dataset import ManifestDataset

 def create_dataloader(manifest_path,
+                      unit_type,
                       vocab_filepath,
                       mean_std_filepath,
+                      spm_model_prefix,
                       augmentation_config='{}',
                       max_duration=float('inf'),
                       min_duration=0.0,
@@ -42,8 +44,10 @@ def create_dataloader(manifest_path,
     dataset = ManifestDataset(
         manifest_path,
+        unit_type,
         vocab_filepath,
         mean_std_filepath,
+        spm_model_prefix=spm_model_prefix,
         augmentation_config=augmentation_config,
         max_duration=max_duration,
         min_duration=min_duration,

@@ -38,8 +38,10 @@ __all__ = [
 class ManifestDataset(Dataset):
     def __init__(self,
                  manifest_path,
+                 unit_type,
                  vocab_filepath,
                  mean_std_filepath,
+                 spm_model_prefix=None,
                  augmentation_config='{}',
                  max_duration=float('inf'),
                  min_duration=0.0,
@@ -57,8 +59,10 @@ class ManifestDataset(Dataset):
         Args:
             manifest_path (str): manifest json file path
-            vocab_filepath (str): vocab file path
+            unit_type (str): token unit type, e.g. char, word, spm
+            vocab_filepath (str): vocab file path.
             mean_std_filepath (str): mean and std file path, which suffix is *.npz
+            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
             augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
             max_duration (float, optional): audio length in seconds must be less than this. Defaults to float('inf').
             min_duration (float, optional): audio length in seconds must be greater than this. Defaults to 0.0.
@@ -78,10 +82,12 @@ class ManifestDataset(Dataset):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
-        self._augmentation_pipeline = AugmentationPipeline(
+        self._audio_augmentation_pipeline = AugmentationPipeline(
             augmentation_config=augmentation_config, random_seed=random_seed)
         self._speech_featurizer = SpeechFeaturizer(
+            unit_type=unit_type,
             vocab_filepath=vocab_filepath,
+            spm_model_prefix=spm_model_prefix,
             specgram_type=specgram_type,
             stride_ms=stride_ms,
             window_ms=window_ms,
@@ -174,7 +180,7 @@ class ManifestDataset(Dataset):
                 self._subfile_from_tar(audio_file), transcript)
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
-        self._augmentation_pipeline.transform_audio(speech_segment)
+        self._audio_augmentation_pipeline.transform_audio(speech_segment)
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
@@ -191,7 +197,7 @@ class ManifestDataset(Dataset):
         def reader():
             for instance in manifest:
-                inst = self.process_utterance(instance["audio_filepath"],
+                inst = self.process_utterance(instance["feat"],
                                               instance["text"])
                 yield inst
@@ -202,5 +208,4 @@ class ManifestDataset(Dataset):
     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        return self.process_utterance(instance["audio_filepath"],
-                                      instance["text"])
+        return self.process_utterance(instance["feat"], instance["text"])

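The dataset now reads the audio (or feature) path from the manifest's `feat` key instead of the old `audio_filepath`. A rough sketch of a post-format_data manifest line, with hypothetical values:

    import json

    line = json.dumps({
        "feat": "wav/1001.flac",    # audio path when feat_type == 'raw'
        "feat_shape": [6.2, 161],   # duration in seconds, plus feat_dim appended by format_data.py
        "text": "hello world",
        "token_shape": [2, 5000],   # token count, vocab size
    })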
@@ -52,6 +52,7 @@ fi
 # format manifest with tokenids, vocab size
 python3 ${MAIN_ROOT}/utils/format_data.py \
     --feat_type "raw" \
+    --cmvn_path "data/mean_std.npz" \
     --unit_type "bpe" \
     --bpe_model_prefix ${bpeprefix} \
     --vocab_path="data/vocab.txt" \

@@ -3,6 +3,7 @@ resampy==0.2.2
 SoundFile==0.9.0.post1
 python_speech_features
 tensorboardX
+sentencepiece
 yacs
 typeguard
 pre-commit

@@ -24,6 +24,7 @@ from deepspeech.frontend.utility import read_manifest
 from deepspeech.frontend.utility import UNK
 from deepspeech.frontend.utility import BLANK
 from deepspeech.frontend.utility import SOS
+from deepspeech.frontend.utility import load_cmvn
 from deepspeech.utils.utility import add_arguments
 from deepspeech.utils.utility import print_arguments

@@ -31,10 +32,13 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
+add_arg('cmvn_path', str,
+        'examples/librispeech/data/mean_std.npz',
+        "Filepath of cmvn.")
 add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
-        "Filepath to write the vocabulary.")
+        "Filepath of the vocabulary.")
 add_arg('manifest_paths', str,
         None,
         "Filepaths of manifests for building vocabulary. "
@@ -51,6 +55,11 @@ args = parser.parse_args()
 def main():
     print_arguments(args)

+    # get feat dim
+    mean, std = load_cmvn(args.cmvn_path, filetype='npz')
+    feat_dim = mean.shape[0]
+    print(f"Feature dim: {feat_dim}")
+
     # read vocab
     vocab = dict()
     with open(args.vocab_path, 'r', encoding='utf-8') as fin:
@@ -58,6 +67,7 @@ def main():
             token = line.strip()
             vocab[token] = len(vocab)
     vocab_size = len(vocab)
+    print(f"Vocab size: {vocab_size}")

     fout = open(args.output_path, 'w', encoding='utf-8')
@@ -78,6 +88,12 @@ def main():
                 line_json['token'] = tokens
                 line_json['token_id'] = tokenids
                 line_json['token_shape'] = (len(tokenids), vocab_size)
+                feat_shape = line_json['feat_shape']
+                assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
+                if args.feat_type == 'raw':
+                    feat_shape.append(feat_dim)
+                else:  # kaldi
+                    raise NotImplementedError('kaldi feat is not supported yet!')
                 fout.write(json.dumps(line_json) + '\n')
         else:
             import sentencepiece as spm
@@ -118,6 +134,12 @@ def main():
                 line_json['token'] = tokens
                 line_json['token_id'] = tokenids
                 line_json['token_shape'] = (len(tokenids), vocab_size)
+                feat_shape = line_json['feat_shape']
+                assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
+                if args.feat_type == 'raw':
+                    feat_shape.append(feat_dim)
+                else:  # kaldi
+                    raise NotImplementedError('kaldi feat is not supported yet!')
                 fout.write(json.dumps(line_json) + '\n')

     fout.close()

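The feature dimension that format_data.py appends falls out of the cmvn stats: per the `(D, 1)` shape comments above, `mean.shape[0]` is D. A sketch with a hypothetical path:

    from deepspeech.frontend.utility import load_cmvn

    mean, std = load_cmvn('data/mean_std.npz', filetype='npz')
    feat_dim = mean.shape[0]   # D: the value appended to each manifest line's feat_shape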