diff --git a/paddlespeech/s2t/models/hubert/kmeans/dump_hubert_feature.py b/paddlespeech/s2t/models/hubert/kmeans/dump_hubert_feature.py
new file mode 100644
index 000000000..643510b9e
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/kmeans/dump_hubert_feature.py
@@ -0,0 +1,112 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_hubert_feature.py)
+import logging
+import os
+import sys
+
+import paddle
+import soundfile as sf
+from feature_utils import dump_feature
+from feature_utils import get_path_iterator
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
+from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertConfig
+from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertModel
+from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertPretrainingConfig
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout, )
+logger = logging.getLogger("dump_hubert_feature")
+
+
+class HubertFeatureReader(object):
+    def __init__(self, config, layer, max_chunk=1600000):
+        self.config = CfgNode(new_allowed=True)
+        self.config.merge_from_file(config)
+        self.config.output_dim = 5002
+        self.model = HubertASR.from_config(self.config)
+
+        with open(self.config.vocab_filepath) as f:
+            dicts = [symbol.strip() for symbol in f.readlines()]
+        task_cfg = self.model.merge_with_parent(HubertPretrainingConfig,
+                                                dict(self.config.task_cfg))
+        model_cfg = self.model.merge_with_parent(HubertConfig,
+                                                 dict(self.config.model_cfg))
+        self.hubert = HubertModel(model_cfg, task_cfg, dicts)
+        model_dict = paddle.load(self.config.hubert_params_path)
+        self.hubert.set_state_dict(model_dict)
+
+        self.model.eval()
+        self.hubert.eval()  # disable dropout for deterministic feature dumps
+        self.layer = layer
+        self.max_chunk = max_chunk
+        logger.info(f" max_chunk = {self.max_chunk}")
+
+    def read_audio(self, path, ref_len=None):
+        wav, _ = sf.read(path, dtype="float32", always_2d=True)
+        if wav.ndim == 2:
+            wav = wav.mean(-1)  # downmix multi-channel audio to mono
+        assert wav.ndim == 1, wav.ndim
+        if ref_len is not None and abs(ref_len - len(wav)) > 160:
+            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
+        return wav
+
+    def get_feats(self, path, ref_len=None):
+        x = self.read_audio(path, ref_len=ref_len)
+        with paddle.no_grad():
+            x = paddle.to_tensor(x, dtype="float32")
+            # NOTE: if the pretraining task config sets normalize=True,
+            # the waveform should be layer-normalized here.
+            x = x.reshape([1, -1])
+
+            feat = []
+            # extract features chunk by chunk along the time axis (axis 1)
+            # to bound peak memory usage on long utterances
+            for start in range(0, x.shape[1], self.max_chunk):
+                x_chunk = x[:, start:start + self.max_chunk]
+                feat_chunk, _ = self.hubert.extract_features(
+                    source=x_chunk,
+                    padding_mask=None,
+                    mask=False,
+                    output_layer=self.layer, )
+                feat.append(feat_chunk)
+        return paddle.concat(feat, axis=1).squeeze(0)
+
+
+def main(tsv_dir, split, config, layer, nshard, rank, feat_dir, max_chunk):
+    reader = HubertFeatureReader(config, layer, max_chunk)
+    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
+    dump_feature(reader, generator, num, split, nshard, rank, feat_dir)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("tsv_dir")
+    parser.add_argument("split")
+    parser.add_argument("config")
+    parser.add_argument("layer", type=int)
+    parser.add_argument("nshard", type=int)
+    parser.add_argument("rank", type=int)
+    parser.add_argument("feat_dir")
+    parser.add_argument("--max_chunk", type=int, default=1600000)
+    args = parser.parse_args()
+    logger.info(args)
+
+    main(**vars(args))
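For a quick sanity check of the reader outside the sharded pipeline, a minimal sketch along these lines should work, assuming a YAML config that provides the `vocab_filepath`, `task_cfg`, `model_cfg`, and `hubert_params_path` fields read in `__init__`; both paths below are placeholders, not files shipped with this patch:

```python
# Hypothetical smoke test for HubertFeatureReader; paths are placeholders.
from dump_hubert_feature import HubertFeatureReader

reader = HubertFeatureReader("conf/hubertASR.yaml", layer=6)
feat = reader.get_feats("sample.flac")
print(feat.shape)  # (num_frames, encoder_dim), taken from transformer layer 6
```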
diff --git a/paddlespeech/s2t/models/hubert/kmeans/dump_km_label.py b/paddlespeech/s2t/models/hubert/kmeans/dump_km_label.py
new file mode 100644
index 000000000..a906cbf79
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/kmeans/dump_km_label.py
@@ -0,0 +1,85 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_km_label.py)
+import logging
+import os
+import sys
+
+import joblib
+import numpy as np
+import paddle
+import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout, )
+logger = logging.getLogger("dump_km_label")
+
+
+class ApplyKmeans(object):
+    def __init__(self, km_path):
+        self.km_model = joblib.load(km_path)
+        self.C_np = self.km_model.cluster_centers_.transpose()  # (dim, n_clusters)
+        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
+
+        self.C = paddle.to_tensor(self.C_np)
+        self.Cnorm = paddle.to_tensor(self.Cnorm_np)
+
+    def __call__(self, x):
+        # squared distance to each centroid, expanded as
+        # ||x - c||^2 = ||x||^2 - 2 * x.c + ||c||^2
+        if isinstance(x, paddle.Tensor):
+            dist = (x.pow(2).sum(axis=1, keepdim=True) -
+                    2 * paddle.matmul(x, self.C) + self.Cnorm)
+            return dist.argmin(axis=1).cpu().numpy()
+        else:
+            dist = ((x**2).sum(1, keepdims=True) - 2 * np.matmul(x, self.C_np)
+                    + self.Cnorm_np)
+            return np.argmin(dist, axis=1)
+
+
+def get_feat_iterator(feat_dir, split, nshard, rank):
+    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
+    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
+    with open(leng_path, "r") as f:
+        lengs = [int(line.rstrip()) for line in f]
+        offsets = [0] + np.cumsum(lengs[:-1]).tolist()
+
+    def iterate():
+        feat = np.load(feat_path, mmap_mode="r")
+        assert feat.shape[0] == (offsets[-1] + lengs[-1])
+        for offset, leng in zip(offsets, lengs):
+            yield feat[offset:offset + leng]
+
+    return iterate, len(lengs)
+
+
+def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir):
+    apply_kmeans = ApplyKmeans(km_path)
+    generator, num = get_feat_iterator(feat_dir, split, nshard, rank)
+    iterator = generator()
+
+    lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km"
+    os.makedirs(lab_dir, exist_ok=True)
+    with open(lab_path, "w") as f:
+        for feat in tqdm.tqdm(iterator, total=num):
+            lab = apply_kmeans(feat).tolist()
+            f.write(" ".join(map(str, lab)) + "\n")
+    logger.info("finished successfully")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("feat_dir")
+    parser.add_argument("split")
+    parser.add_argument("km_path")
+    parser.add_argument("nshard", type=int)
+    parser.add_argument("rank", type=int)
+    parser.add_argument("lab_dir")
+    args = parser.parse_args()
+    logger.info(args)
+
+    dump_label(**vars(args))
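`ApplyKmeans` never materializes pairwise differences: it expands the squared Euclidean distance as ||x - c||^2 = ||x||^2 - 2·x·c + ||c||^2, with ||c||^2 precomputed once per centroid, so assignment is a single matrix multiply. A small NumPy check of that identity against the brute-force distance, on synthetic data with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((7, 39)).astype("float32")    # (frames, feat_dim)
C = rng.standard_normal((39, 500)).astype("float32")  # (feat_dim, n_clusters)
Cnorm = (C**2).sum(0, keepdims=True)

# expanded form used by ApplyKmeans.__call__
dist = (x**2).sum(1, keepdims=True) - 2 * np.matmul(x, C) + Cnorm
# brute-force squared distances for comparison
brute = ((x[:, None, :] - C.T[None, :, :])**2).sum(-1)

assert np.allclose(dist, brute, atol=1e-3)
assert (dist.argmin(1) == brute.argmin(1)).all()
```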
diff --git a/paddlespeech/s2t/models/hubert/kmeans/dump_mfcc_feature.py b/paddlespeech/s2t/models/hubert/kmeans/dump_mfcc_feature.py
new file mode 100644
index 000000000..a4f766754
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/kmeans/dump_mfcc_feature.py
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/dump_mfcc_feature.py)
+import logging
+import os
+import sys
+
+import paddle
+import paddleaudio
+import soundfile as sf
+from feature_utils import dump_feature
+from feature_utils import get_path_iterator
+from python_speech_features import delta
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout, )
+logger = logging.getLogger("dump_mfcc_feature")
+
+
+class MfccFeatureReader(object):
+    def __init__(self, sample_rate):
+        self.sample_rate = sample_rate
+
+    def read_audio(self, path, ref_len=None):
+        wav, _ = sf.read(path, dtype="float32", always_2d=True)
+        if wav.ndim == 2:
+            wav = wav.mean(-1)  # downmix multi-channel audio to mono
+        if ref_len is not None and abs(ref_len - len(wav)) > 160:
+            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
+        return wav
+
+    def get_feats(self, path, ref_len=None):
+        x = self.read_audio(path, ref_len=ref_len)
+        with paddle.no_grad():
+            x = paddle.to_tensor(x, dtype="float32")
+            x = x.reshape([1, -1])
+
+            mfccs = paddleaudio.compliance.kaldi.mfcc(
+                waveform=x,
+                sr=self.sample_rate,
+                use_energy=False, )  # (time, freq)
+
+            # first- and second-order deltas, computed on a numpy copy
+            deltas = delta(mfccs.numpy(), 2)
+            ddeltas = delta(deltas, 2)
+            concat = paddle.concat(
+                x=[mfccs, paddle.to_tensor(deltas), paddle.to_tensor(ddeltas)],
+                axis=-1)
+            return concat
+
+
+def main(tsv_dir, split, nshard, rank, feat_dir, sample_rate):
+    reader = MfccFeatureReader(sample_rate)
+    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
+    dump_feature(reader, generator, num, split, nshard, rank, feat_dir)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("tsv_dir")
+    parser.add_argument("split")
+    parser.add_argument("nshard", type=int)
+    parser.add_argument("rank", type=int)
+    parser.add_argument("feat_dir")
+    parser.add_argument("--sample_rate", type=int, default=16000)
+    args = parser.parse_args()
+    logger.info(args)
+
+    main(**vars(args))
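Each frame produced by this reader is the MFCC coefficients plus their first- and second-order deltas, matching the MFCC stage of the original HuBERT recipe (39 dimensions with the default `num_ceps=13` of `paddleaudio.compliance.kaldi.mfcc`). A minimal self-contained check on a synthetic tone; the temp path is arbitrary:

```python
# Synthetic-input check for MfccFeatureReader; /tmp/tone.wav is a throwaway file.
import numpy as np
import soundfile as sf
from dump_mfcc_feature import MfccFeatureReader

sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
sf.write("/tmp/tone.wav", 0.1 * np.sin(2 * np.pi * 440 * t), sr)

reader = MfccFeatureReader(sample_rate=sr)
feat = reader.get_feats("/tmp/tone.wav")
print(feat.shape)  # (num_frames, 39): 13 MFCCs + deltas + delta-deltas
```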
diff --git a/paddlespeech/s2t/models/hubert/kmeans/feature_utils.py b/paddlespeech/s2t/models/hubert/kmeans/feature_utils.py
new file mode 100644
index 000000000..a5a2f69fe
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/kmeans/feature_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py)
+import logging
+import os
+import sys
+
+import tqdm
+from npy_append_array import NpyAppendArray
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout, )
+logger = logging.getLogger("feature_utils")
+
+
+def get_shard_range(tot, nshard, rank):
+    assert rank < nshard and rank >= 0, f"invalid rank/nshard {rank}/{nshard}"
+    start = round(tot / nshard * rank)
+    end = round(tot / nshard * (rank + 1))
+    assert start < end, f"start={start}, end={end}"
+    logger.info(f"rank {rank} of {nshard}, process {end-start} "
+                f"({start}-{end}) out of {tot}")
+    return start, end
+
+
+def get_path_iterator(tsv, nshard, rank):
+    with open(tsv, "r") as f:
+        root = f.readline().rstrip()
+        lines = [line.rstrip() for line in f]
+        start, end = get_shard_range(len(lines), nshard, rank)
+        lines = lines[start:end]
+
+    def iterate():
+        for line in lines:
+            subpath, nsample = line.split("\t")
+            yield f"{root}/{subpath}", int(nsample)
+
+    return iterate, len(lines)
+
+
+def dump_feature(reader, generator, num, split, nshard, rank, feat_dir):
+    iterator = generator()
+
+    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
+    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
+
+    os.makedirs(feat_dir, exist_ok=True)
+    if os.path.exists(feat_path):
+        os.remove(feat_path)
+    feat_f = NpyAppendArray(feat_path)
+
+    with open(leng_path, "w") as leng_f:
+        for path, nsample in tqdm.tqdm(iterator, total=num):
+            feat = reader.get_feats(path, nsample)
+            feat_f.append(feat.cpu().numpy())
+            leng_f.write(f"{len(feat)}\n")
+    logger.info("finished successfully")
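`get_shard_range` splits `tot` items into `nshard` contiguous, nearly equal ranges by rounding `tot / nshard * rank`; because adjacent shards compute the same boundary, the ranges are disjoint and cover every index exactly once. A quick check with made-up numbers:

```python
from feature_utils import get_shard_range

tot, nshard = 10, 3
ranges = [get_shard_range(tot, nshard, rank) for rank in range(nshard)]
# rounding gives [(0, 3), (3, 7), (7, 10)]
assert ranges[0][0] == 0 and ranges[-1][1] == tot
for (_, end), (start, _) in zip(ranges, ranges[1:]):
    assert end == start  # contiguous and non-overlapping
```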
diff --git a/paddlespeech/s2t/models/hubert/kmeans/learn_kmeans.py b/paddlespeech/s2t/models/hubert/kmeans/learn_kmeans.py
new file mode 100644
index 000000000..f30d75431
--- /dev/null
+++ b/paddlespeech/s2t/models/hubert/kmeans/learn_kmeans.py
@@ -0,0 +1,133 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from Fairseq 2023 (https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/learn_kmeans.py)
+import logging
+import os
+import sys
+
+import joblib
+import numpy as np
+from sklearn.cluster import MiniBatchKMeans
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    stream=sys.stdout, )
+logger = logging.getLogger("learn_kmeans")
+
+
+def get_km_model(
+        n_clusters,
+        init,
+        max_iter,
+        batch_size,
+        tol,
+        max_no_improvement,
+        n_init,
+        reassignment_ratio, ):
+    return MiniBatchKMeans(
+        n_clusters=n_clusters,
+        init=init,
+        max_iter=max_iter,
+        batch_size=batch_size,
+        verbose=1,
+        compute_labels=False,
+        tol=tol,
+        max_no_improvement=max_no_improvement,
+        init_size=None,
+        n_init=n_init,
+        reassignment_ratio=reassignment_ratio, )
+
+
+def load_feature_shard(feat_dir, split, nshard, rank, percent):
+    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
+    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
+    with open(leng_path, "r") as f:
+        lengs = [int(line.rstrip()) for line in f]
+        offsets = [0] + np.cumsum(lengs[:-1]).tolist()
+
+    if percent < 0:
+        return np.load(feat_path, mmap_mode="r")
+    else:
+        nsample = int(np.ceil(len(lengs) * percent))
+        indices = np.random.choice(len(lengs), nsample, replace=False)
+        feat = np.load(feat_path, mmap_mode="r")
+        sampled_feat = np.concatenate(
+            [feat[offsets[i]:offsets[i] + lengs[i]] for i in indices], axis=0)
+        logger.info(
+            (f"sampled {nsample} utterances, {len(sampled_feat)} frames "
+             f"from shard {rank}/{nshard}"))
+        return sampled_feat
+
+
+def load_feature(feat_dir, split, nshard, seed, percent):
+    assert percent <= 1.0
+    feat = np.concatenate(
+        [
+            load_feature_shard(feat_dir, split, nshard, r, percent)
+            for r in range(nshard)
+        ],
+        axis=0, )
+    logger.info(f"loaded feature with shape {feat.shape}")
+    return feat
+
+
+def learn_kmeans(
+        feat_dir,
+        split,
+        nshard,
+        km_path,
+        n_clusters,
+        seed,
+        percent,
+        init,
+        max_iter,
+        batch_size,
+        tol,
+        n_init,
+        reassignment_ratio,
+        max_no_improvement, ):
+    np.random.seed(seed)
+    feat = load_feature(feat_dir, split, nshard, seed, percent)
+    km_model = get_km_model(
+        n_clusters,
+        init,
+        max_iter,
+        batch_size,
+        tol,
+        max_no_improvement,
+        n_init,
+        reassignment_ratio, )
+    km_model.fit(feat)
+    joblib.dump(km_model, km_path)
+
+    # MiniBatchKMeans.score returns the negative inertia
+    inertia = -km_model.score(feat) / len(feat)
+    logger.info("total inertia: %.5f", inertia)
+    logger.info("finished successfully")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("feat_dir", type=str)
+    parser.add_argument("split", type=str)
+    parser.add_argument("nshard", type=int)
+    parser.add_argument("km_path", type=str)
+    parser.add_argument("n_clusters", type=int)
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument(
+        "--percent", default=-1, type=float, help="sample a subset; -1 for all")
+    parser.add_argument("--init", default="k-means++")
+    parser.add_argument("--max_iter", default=100, type=int)
+    parser.add_argument("--batch_size", default=10000, type=int)
+    parser.add_argument("--tol", default=0.0, type=float)
+    parser.add_argument("--max_no_improvement", default=100, type=int)
+    parser.add_argument("--n_init", default=20, type=int)
+    parser.add_argument("--reassignment_ratio", default=0.0, type=float)
+    args = parser.parse_args()
+    logger.info(args)
+
+    learn_kmeans(**vars(args))
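Taken together, the intended flow is: dump features per shard, fit k-means on (a sample of) them, then dump cluster labels. A single-shard sketch of that pipeline, with placeholder paths and illustrative hyperparameters rather than values fixed by this patch:

```python
# Hypothetical single-shard (nshard=1, rank=0) run; all paths are placeholders.
from dump_km_label import dump_label
from dump_mfcc_feature import main as dump_mfcc
from learn_kmeans import learn_kmeans

dump_mfcc(tsv_dir="data", split="train", nshard=1, rank=0,
          feat_dir="feats", sample_rate=16000)
learn_kmeans(feat_dir="feats", split="train", nshard=1, km_path="km.bin",
             n_clusters=100, seed=0, percent=0.1, init="k-means++",
             max_iter=100, batch_size=10000, tol=0.0, n_init=20,
             reassignment_ratio=0.0, max_no_improvement=100)
dump_label(feat_dir="feats", split="train", km_path="km.bin",
           nshard=1, rank=0, lab_dir="labels")
```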