parent
3ad55a31e7
commit
2bdd87633a
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_hubert_feature.py)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
import tqdm
|
||||||
|
from feature_utils import dump_feature
|
||||||
|
from feature_utils import get_path_iterator
|
||||||
|
from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
|
||||||
|
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertConfig
|
||||||
|
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertModel
|
||||||
|
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertPretrainingConfig
|
||||||
|
from paddlespeech.s2t.modules.align import LayerNorm
|
||||||
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||||
|
stream=sys.stdout, )
|
||||||
|
logger = logging.getLogger("dump_hubert_feature")
|
||||||
|
|
||||||
|
|
||||||
|
class HubertFeatureReader(object):
|
||||||
|
def __init__(self, config, layer, max_chunk=1600000):
|
||||||
|
self.config = CfgNode(new_allowed=True)
|
||||||
|
self.config.merge_from_file(config)
|
||||||
|
self.config.output_dim = 5002
|
||||||
|
# import pdb
|
||||||
|
# pdb.set_trace()
|
||||||
|
model = HubertASR.from_config(self.config)
|
||||||
|
# model_dict = paddle.load(self.config.hubert_params_path)
|
||||||
|
# model.hubert.set_state_dict(model_dict)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
with open(self.config.vocab_filepath) as f:
|
||||||
|
dicts = [symbol.strip() for symbol in f.readlines()]
|
||||||
|
task_cfg = self.model.merge_with_parent(HubertPretrainingConfig,
|
||||||
|
dict(self.config.task_cfg))
|
||||||
|
model_cfg = self.model.merge_with_parent(HubertConfig,
|
||||||
|
dict(self.config.model_cfg))
|
||||||
|
self.hubert = HubertModel(model_cfg, task_cfg, dicts)
|
||||||
|
model_dict = paddle.load(self.config.hubert_params_path)
|
||||||
|
self.hubert.set_state_dict(model_dict)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
self.layer = layer
|
||||||
|
self.max_chunk = max_chunk
|
||||||
|
logger.info(f" max_chunk = {self.max_chunk}")
|
||||||
|
|
||||||
|
def read_audio(self, path, ref_len=None):
|
||||||
|
wav, _ = sf.read(path, dtype="float32", always_2d=True)
|
||||||
|
if wav.ndim == 2:
|
||||||
|
wav = wav.mean(-1)
|
||||||
|
assert wav.ndim == 1, wav.ndim
|
||||||
|
if ref_len is not None and abs(ref_len - len(wav)) > 160:
|
||||||
|
logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
|
||||||
|
return wav
|
||||||
|
|
||||||
|
def get_feats(self, path, ref_len=None):
|
||||||
|
x = self.read_audio(path, ref_len=ref_len)
|
||||||
|
with paddle.no_grad():
|
||||||
|
x = paddle.to_tensor(x).float().cuda()
|
||||||
|
# if self.task.cfg.normalize:
|
||||||
|
# x = LayerNorm(x, x.shape)
|
||||||
|
x = x.view(1, -1)
|
||||||
|
|
||||||
|
feat = []
|
||||||
|
for start in range(0, x.shape[0], self.max_chunk):
|
||||||
|
x_chunk = x[:, start:start + self.max_chunk]
|
||||||
|
feat_chunk, _ = self.hubert.extract_features(
|
||||||
|
source=x_chunk,
|
||||||
|
padding_mask=None,
|
||||||
|
mask=False,
|
||||||
|
output_layer=self.layer, )
|
||||||
|
feat.append(feat_chunk)
|
||||||
|
return paddle.concat(feat, 1).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def main(tsv_dir, split, config, layer, nshard, rank, feat_dir, max_chunk):
|
||||||
|
reader = HubertFeatureReader(config, layer, max_chunk)
|
||||||
|
generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
|
||||||
|
dump_feature(reader, generator, num, split, nshard, rank, feat_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("tsv_dir")
|
||||||
|
parser.add_argument("split")
|
||||||
|
parser.add_argument("config")
|
||||||
|
parser.add_argument("layer", type=int)
|
||||||
|
parser.add_argument("nshard", type=int)
|
||||||
|
parser.add_argument("rank", type=int)
|
||||||
|
parser.add_argument("feat_dir")
|
||||||
|
parser.add_argument("--max_chunk", type=int, default=1600000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
logger.info(args)
|
||||||
|
|
||||||
|
main(**vars(args))
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_km_label.py)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||||
|
stream=sys.stdout, )
|
||||||
|
logger = logging.getLogger("dump_km_label")
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyKmeans(object):
|
||||||
|
def __init__(self, km_path):
|
||||||
|
self.km_model = joblib.load(km_path)
|
||||||
|
self.C_np = self.km_model.cluster_centers_.transpose()
|
||||||
|
self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
|
||||||
|
|
||||||
|
self.C = paddle.to_tensor(self.C_np)
|
||||||
|
self.Cnorm = paddle.to_tensor(self.Cnorm_np)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if isinstance(x, paddle.Tensor):
|
||||||
|
dist = (x.pow(2).sum(1, keepdim=True) - 2 * paddle.matmul(x, self.C)
|
||||||
|
+ self.Cnorm)
|
||||||
|
return dist.argmin(dim=1).cpu().numpy()
|
||||||
|
else:
|
||||||
|
dist = ((x**2).sum(1, keepdims=True) - 2 * np.matmul(x, self.C_np) +
|
||||||
|
self.Cnorm_np)
|
||||||
|
return np.argmin(dist, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_feat_iterator(feat_dir, split, nshard, rank):
|
||||||
|
feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
|
||||||
|
leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
|
||||||
|
with open(leng_path, "r") as f:
|
||||||
|
lengs = [int(line.rstrip()) for line in f]
|
||||||
|
offsets = [0] + np.cumsum(lengs[:-1]).tolist()
|
||||||
|
|
||||||
|
def iterate():
|
||||||
|
feat = np.load(feat_path, mmap_mode="r")
|
||||||
|
assert feat.shape[0] == (offsets[-1] + lengs[-1])
|
||||||
|
for offset, leng in zip(offsets, lengs):
|
||||||
|
yield feat[offset:offset + leng]
|
||||||
|
|
||||||
|
return iterate, len(lengs)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir):
|
||||||
|
apply_kmeans = ApplyKmeans(km_path)
|
||||||
|
generator, num = get_feat_iterator(feat_dir, split, nshard, rank)
|
||||||
|
iterator = generator()
|
||||||
|
|
||||||
|
lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km"
|
||||||
|
os.makedirs(lab_dir, exist_ok=True)
|
||||||
|
with open(lab_path, "w") as f:
|
||||||
|
for feat in tqdm.tqdm(iterator, total=num):
|
||||||
|
lab = apply_kmeans(feat).tolist()
|
||||||
|
f.write(" ".join(map(str, lab)) + "\n")
|
||||||
|
logger.info("finished successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("feat_dir")
|
||||||
|
parser.add_argument("split")
|
||||||
|
parser.add_argument("km_path")
|
||||||
|
parser.add_argument("nshard", type=int)
|
||||||
|
parser.add_argument("rank", type=int)
|
||||||
|
parser.add_argument("lab_dir")
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(str(args))
|
||||||
|
|
||||||
|
dump_label(**vars(args))
|
@ -0,0 +1,72 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_mfcc_feature.py)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddleaudio
|
||||||
|
import soundfile as sf
|
||||||
|
from feature_utils import dump_feature
|
||||||
|
from feature_utils import get_path_iterator
|
||||||
|
from python_speech_features import delta
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||||
|
stream=sys.stdout, )
|
||||||
|
logger = logging.getLogger("dump_mfcc_feature")
|
||||||
|
|
||||||
|
|
||||||
|
class MfccFeatureReader(object):
|
||||||
|
def __init__(self, sample_rate):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
|
||||||
|
def read_audio(self, path, ref_len=None):
|
||||||
|
wav, _ = sf.read(path, dtype="float32", always_2d=True)
|
||||||
|
if ref_len is not None and abs(ref_len - len(wav)) > 160:
|
||||||
|
logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
|
||||||
|
return wav
|
||||||
|
|
||||||
|
def get_feats(self, path, ref_len=None):
|
||||||
|
x = self.read_audio(path, ref_len=ref_len)
|
||||||
|
with paddle.no_grad():
|
||||||
|
x = paddle.to_tensor(x, dtype="float32")
|
||||||
|
x = x.reshape([1, -1])
|
||||||
|
|
||||||
|
mfccs = paddleaudio.compliance.kaldi.mfcc(
|
||||||
|
waveform=x,
|
||||||
|
sr=self.sample_rate,
|
||||||
|
use_energy=False, ) # (freq, time)
|
||||||
|
|
||||||
|
deltas = delta(mfccs, 2)
|
||||||
|
ddeltas = delta(deltas, 2)
|
||||||
|
concat = paddle.concat(
|
||||||
|
x=[mfccs, paddle.to_tensor(deltas), paddle.to_tensor(ddeltas)],
|
||||||
|
axis=-1)
|
||||||
|
return concat
|
||||||
|
|
||||||
|
|
||||||
|
def main(tsv_dir, split, nshard, rank, feat_dir, sample_rate):
|
||||||
|
reader = MfccFeatureReader(sample_rate)
|
||||||
|
generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
|
||||||
|
dump_feature(reader, generator, num, split, nshard, rank, feat_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("tsv_dir")
|
||||||
|
parser.add_argument("split")
|
||||||
|
parser.add_argument("nshard", type=int)
|
||||||
|
parser.add_argument("rank", type=int)
|
||||||
|
parser.add_argument("feat_dir")
|
||||||
|
parser.add_argument("--sample_rate", type=int, default=16000)
|
||||||
|
args = parser.parse_args()
|
||||||
|
logger.info(args)
|
||||||
|
|
||||||
|
main(**vars(args))
|
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
from npy_append_array import NpyAppendArray
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||||
|
stream=sys.stdout, )
|
||||||
|
logger = logging.getLogger("feature_utils")
|
||||||
|
|
||||||
|
|
||||||
|
def get_shard_range(tot, nshard, rank):
|
||||||
|
assert rank < nshard and rank >= 0, f"invaid rank/nshard {rank}/{nshard}"
|
||||||
|
start = round(tot / nshard * rank)
|
||||||
|
end = round(tot / nshard * (rank + 1))
|
||||||
|
assert start < end, f"start={start}, end={end}"
|
||||||
|
logger.info(f"rank {rank} of {nshard}, process {end-start} "
|
||||||
|
f"({start}-{end}) out of {tot}")
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def get_path_iterator(tsv, nshard, rank):
|
||||||
|
with open(tsv, "r") as f:
|
||||||
|
root = f.readline().rstrip()
|
||||||
|
lines = [line.rstrip() for line in f]
|
||||||
|
start, end = get_shard_range(len(lines), nshard, rank)
|
||||||
|
lines = lines[start:end]
|
||||||
|
|
||||||
|
def iterate():
|
||||||
|
for line in lines:
|
||||||
|
subpath = line
|
||||||
|
subpath, nsample = line.split("\t")
|
||||||
|
yield f"{root}/{subpath}", int(nsample)
|
||||||
|
|
||||||
|
return iterate, len(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_feature(reader, generator, num, split, nshard, rank, feat_dir):
|
||||||
|
iterator = generator()
|
||||||
|
|
||||||
|
feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
|
||||||
|
leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
|
||||||
|
|
||||||
|
os.makedirs(feat_dir, exist_ok=True)
|
||||||
|
if os.path.exists(feat_path):
|
||||||
|
os.remove(feat_path)
|
||||||
|
feat_f = NpyAppendArray(feat_path)
|
||||||
|
|
||||||
|
with open(leng_path, "w") as leng_f:
|
||||||
|
for path, nsample in tqdm.tqdm(iterator, total=num):
|
||||||
|
feat = reader.get_feats(path, nsample)
|
||||||
|
feat_f.append(feat.cpu().numpy())
|
||||||
|
leng_f.write(f"{len(feat)}\n")
|
||||||
|
logger.info("finished successfully")
|
@ -0,0 +1,133 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/learn_kmeans.py)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.cluster import MiniBatchKMeans
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||||
|
stream=sys.stdout, )
|
||||||
|
logger = logging.getLogger("learn_kmeans")
|
||||||
|
|
||||||
|
|
||||||
|
def get_km_model(
|
||||||
|
n_clusters,
|
||||||
|
init,
|
||||||
|
max_iter,
|
||||||
|
batch_size,
|
||||||
|
tol,
|
||||||
|
max_no_improvement,
|
||||||
|
n_init,
|
||||||
|
reassignment_ratio, ):
|
||||||
|
return MiniBatchKMeans(
|
||||||
|
n_clusters=n_clusters,
|
||||||
|
init=init,
|
||||||
|
max_iter=max_iter,
|
||||||
|
batch_size=batch_size,
|
||||||
|
verbose=1,
|
||||||
|
compute_labels=False,
|
||||||
|
tol=tol,
|
||||||
|
max_no_improvement=max_no_improvement,
|
||||||
|
init_size=None,
|
||||||
|
n_init=n_init,
|
||||||
|
reassignment_ratio=reassignment_ratio, )
|
||||||
|
|
||||||
|
|
||||||
|
def load_feature_shard(feat_dir, split, nshard, rank, percent):
|
||||||
|
feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
|
||||||
|
leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
|
||||||
|
with open(leng_path, "r") as f:
|
||||||
|
lengs = [int(line.rstrip()) for line in f]
|
||||||
|
offsets = [0] + np.cumsum(lengs[:-1]).tolist()
|
||||||
|
|
||||||
|
if percent < 0:
|
||||||
|
return np.load(feat_path, mmap_mode="r")
|
||||||
|
else:
|
||||||
|
nsample = int(np.ceil(len(lengs) * percent))
|
||||||
|
indices = np.random.choice(len(lengs), nsample, replace=False)
|
||||||
|
feat = np.load(feat_path, mmap_mode="r")
|
||||||
|
sampled_feat = np.concatenate(
|
||||||
|
[feat[offsets[i]:offsets[i] + lengs[i]] for i in indices], axis=0)
|
||||||
|
logger.info(
|
||||||
|
(f"sampled {nsample} utterances, {len(sampled_feat)} frames "
|
||||||
|
f"from shard {rank}/{nshard}"))
|
||||||
|
return sampled_feat
|
||||||
|
|
||||||
|
|
||||||
|
def load_feature(feat_dir, split, nshard, seed, percent):
|
||||||
|
assert percent <= 1.0
|
||||||
|
feat = np.concatenate(
|
||||||
|
[
|
||||||
|
load_feature_shard(feat_dir, split, nshard, r, percent)
|
||||||
|
for r in range(nshard)
|
||||||
|
],
|
||||||
|
axis=0, )
|
||||||
|
logging.info(f"loaded feature with dimension {feat.shape}")
|
||||||
|
return feat
|
||||||
|
|
||||||
|
|
||||||
|
def learn_kmeans(
|
||||||
|
feat_dir,
|
||||||
|
split,
|
||||||
|
nshard,
|
||||||
|
km_path,
|
||||||
|
n_clusters,
|
||||||
|
seed,
|
||||||
|
percent,
|
||||||
|
init,
|
||||||
|
max_iter,
|
||||||
|
batch_size,
|
||||||
|
tol,
|
||||||
|
n_init,
|
||||||
|
reassignment_ratio,
|
||||||
|
max_no_improvement, ):
|
||||||
|
np.random.seed(seed)
|
||||||
|
feat = load_feature(feat_dir, split, nshard, seed, percent)
|
||||||
|
km_model = get_km_model(
|
||||||
|
n_clusters,
|
||||||
|
init,
|
||||||
|
max_iter,
|
||||||
|
batch_size,
|
||||||
|
tol,
|
||||||
|
max_no_improvement,
|
||||||
|
n_init,
|
||||||
|
reassignment_ratio, )
|
||||||
|
km_model.fit(feat)
|
||||||
|
joblib.dump(km_model, km_path)
|
||||||
|
|
||||||
|
inertia = -km_model.score(feat) / len(feat)
|
||||||
|
logger.info("total intertia: %.5f", inertia)
|
||||||
|
logger.info("finished successfully")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("feat_dir", type=str)
|
||||||
|
parser.add_argument("split", type=str)
|
||||||
|
parser.add_argument("nshard", type=int)
|
||||||
|
parser.add_argument("km_path", type=str)
|
||||||
|
parser.add_argument("n_clusters", type=int)
|
||||||
|
parser.add_argument("--seed", default=0, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--percent", default=-1, type=float, help="sample a subset; -1 for all")
|
||||||
|
parser.add_argument("--init", default="k-means++")
|
||||||
|
parser.add_argument("--max_iter", default=100, type=int)
|
||||||
|
parser.add_argument("--batch_size", default=10000, type=int)
|
||||||
|
parser.add_argument("--tol", default=0.0, type=float)
|
||||||
|
parser.add_argument("--max_no_improvement", default=100, type=int)
|
||||||
|
parser.add_argument("--n_init", default=20, type=int)
|
||||||
|
parser.add_argument("--reassignment_ratio", default=0.0, type=float)
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(str(args))
|
||||||
|
|
||||||
|
learn_kmeans(**vars(args))
|
Loading…
Reference in new issue