From bc53f726fece7f1536ad5c0d049c79686af1caee Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Fri, 8 Apr 2022 21:34:16 +0800 Subject: [PATCH] convert dataset format to paddlespeech, test=doc --- examples/ami/sd0/local/compute_embdding.py | 7 +- examples/ami/sd0/local/process.sh | 4 +- examples/ami/sd0/run.sh | 28 +++-- paddlespeech/vector/io/dataset_from_json.py | 116 ++++++++++++++++++++ 4 files changed, 138 insertions(+), 17 deletions(-) create mode 100644 paddlespeech/vector/io/dataset_from_json.py diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py index 30d49d51..dc824d7c 100644 --- a/examples/ami/sd0/local/compute_embdding.py +++ b/examples/ami/sd0/local/compute_embdding.py @@ -19,7 +19,6 @@ import sys import numpy as np import paddle -from ami_dataset import AMIDataset from paddle.io import BatchSampler from paddle.io import DataLoader from tqdm.contrib import tqdm @@ -28,6 +27,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.cluster.diarization import EmbeddingMeta from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.io.dataset_from_json import JSONDataset from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.training.seeding import seed_everything @@ -65,7 +65,7 @@ def create_dataloader(json_file, batch_size): """ # create datasets - dataset = AMIDataset( + dataset = JSONDataset( json_file=json_file, feat_type='melspectrogram', n_mels=config.n_mels, @@ -93,8 +93,7 @@ def main(args, config): ecapa_tdnn = EcapaTdnn(**config.model) # stage2: build the speaker verification eval instance with backbone model - model = SpeakerIdetification( - backbone=ecapa_tdnn, num_class=1) + model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1) # stage3: load the pre-trained model # we get the last model from the epoch and save_interval diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh index 72c58b10..1dfd11b8 100755 --- a/examples/ami/sd0/local/process.sh +++ b/examples/ami/sd0/local/process.sh @@ -4,7 +4,6 @@ stage=0 set=L . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -set -u set -o pipefail data_folder=$1 @@ -12,6 +11,7 @@ manual_annot_folder=$2 save_folder=$3 pretrained_model_dir=$4 conf_path=$5 +device=$6 ref_rttm_dir=${save_folder}/ref_rttms meta_data_dir=${save_folder}/metadata @@ -35,7 +35,7 @@ if [ ${stage} -le 1 ]; then for name in dev eval; do python local/compute_embdding.py --config ${conf_path} \ --data-dir ${save_folder} \ - --device gpu:0 \ + --device ${device} \ --dataset ${name} \ --load-checkpoint ${pretrained_model_dir} done diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh index fc6a91cc..9035f595 100644 --- a/examples/ami/sd0/run.sh +++ b/examples/ami/sd0/run.sh @@ -3,8 +3,7 @@ . ./path.sh || exit 1; set -e -stage=1 -stop_stage=50 +stage=0 #TARGET_DIR=${MAIN_ROOT}/dataset/ami TARGET_DIR=/home/dataset/AMI @@ -12,15 +11,14 @@ data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ save_folder=./save -pretraind_model_dir=${save_folder}/model - +pretraind_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model conf_path=conf/ecapa_tdnn.yaml - +device=gpu . 
${MAIN_ROOT}/utils/parse_options.sh || exit 1;
-if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    # Prepare data and model
+if [ $stage -le 0 ]; then
+    # Prepare data
     # Download AMI corpus, You need around 10GB of free space to get whole data
     # The signals are too large to package in this way,
     # so you need to use the chooser to indicate which ones you wish to download
@@ -29,12 +27,20 @@ if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "Signals: "
     echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
     echo "2) Select media streams: Just select Headset mix"
-    # Download the pretrained Model from HuggingFace or other pretrained model
-    echo "Please download the pretrained ECAPA-TDNN Model and put the pretrainde model in given path: "${pretraind_model_dir}
 fi
 
-if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+if [ $stage -le 1 ]; then
+    # Download the pretrained model
+    wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
+    rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    echo "Downloaded the pretrained ECAPA-TDNN model to: "${pretraind_model_dir}
+fi
+
+if [ $stage -le 2 ]; then
     # Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams.
-    bash ./local/process.sh ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path} || exit 1
+    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path}
+    bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
+        ${save_folder} ${pretraind_model_dir} ${conf_path} ${device} || exit 1
 fi
 
diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py
new file mode 100644
index 00000000..5ffd2c18
--- /dev/null
+++ b/paddlespeech/vector/io/dataset_from_json.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from dataclasses import dataclass
+from dataclasses import fields
+from paddle.io import Dataset
+
+from paddleaudio import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddleaudio.compliance.librosa import mfcc
+
+
+@dataclass
+class meta_info:
+    """The audio meta info in the vector JSONDataset
+    Args:
+        id (str): the segment name
+        duration (float): segment duration
+        wav (str): wav file path
+        start (int): start sample index in the original wav file
+        stop (int): stop sample index in the original wav file
+        record_id (str): the id of the original recording
+    """
+    id: str
+    duration: float
+    wav: str
+    start: int
+    stop: int
+    record_id: str
+
+
+# feature types supported by the json dataset
+feat_funcs = {
+    'raw': None,
+    'melspectrogram': melspectrogram,
+    'mfcc': mfcc,
+}
+
+
+class JSONDataset(Dataset):
+    """
+    Dataset that loads audio segments and features from a JSON metadata file.
+    """
+
+    def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
+        """
+        Args:
+            json_file (:obj:`str`): Data prep JSON file.
+            feat_type (:obj:`str`, `optional`, defaults to `raw`):
+                It identifies the feature type that the user wants to extract from an audio file.
+            kwargs: Keyword arguments passed on to the feature extraction function, e.g. n_mels.
+        """
+        if feat_type not in feat_funcs.keys():
+            raise RuntimeError(
+                f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
+            )
+
+        self.json_file = json_file
+        self.feat_type = feat_type
+        self.feat_config = kwargs
+        self._data = self._get_data()
+        super(JSONDataset, self).__init__()
+
+    def _get_data(self):
+        with open(self.json_file, "r") as f:
+            meta_data = json.load(f)
+        data = []
+        for key in meta_data:
+            sub_seg = meta_data[key]["wav"]
+            wav = sub_seg["file"]
+            duration = sub_seg["duration"]
+            start = sub_seg["start"]
+            stop = sub_seg["stop"]
+            rec_id = str(key).rsplit("_", 2)[0]
+            data.append(
+                meta_info(
+                    str(key),
+                    float(duration), wav, int(start), int(stop), str(rec_id)))
+        return data
+
+    def _convert_to_record(self, idx: int):
+        sample = self._data[idx]
+
+        record = {}
+        # copy every field of the meta_info dataclass into the record dict
+        for field in fields(sample):
+            record[field.name] = getattr(sample, field.name)
+
+        waveform, sr = load_audio(record['wav'])
+        waveform = waveform[record['start']:record['stop']]
+
+        feat_func = feat_funcs[self.feat_type]
+        feat = feat_func(
+            waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+        record.update({'feat': feat})
+
+        return record
+
+    def __getitem__(self, idx):
+        return self._convert_to_record(idx)
+
+    def __len__(self):
+        return len(self._data)
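
Note for reviewers: below is a minimal usage sketch, not part of the patch. It shows the JSON metadata layout that JSONDataset._get_data() appears to expect and how a record is read back, mirroring create_dataloader() in examples/ami/sd0/local/compute_embdding.py. The segment key, wav path, and n_mels value are illustrative placeholders inferred from the code above, not files or settings shipped with this change.

# Minimal usage sketch (not part of the patch). The metadata layout is inferred
# from JSONDataset._get_data(): each top-level key names a sub-segment, its
# recording id is recovered with key.rsplit("_", 2)[0], and start/stop are
# sample indices into the wav file. The wav path is a placeholder and must
# point at a real recording for this to run.
import json

from paddlespeech.vector.io.dataset_from_json import JSONDataset

meta = {
    "ES2011a_00000_00100": {  # illustrative segment key -> record_id "ES2011a"
        "wav": {
            "file": "/path/to/amicorpus/ES2011a/audio/ES2011a.Mix-Headset.wav",  # placeholder
            "duration": 1.0,   # seconds
            "start": 0,        # first sample of the segment
            "stop": 16000,     # last sample (exclusive)
        }
    }
}
with open("subsegments.json", "w") as f:
    json.dump(meta, f, indent=2)

# feat_type and n_mels mirror what create_dataloader() passes in
# compute_embdding.py (n_mels comes from the config there; 80 is only an
# example value).
dataset = JSONDataset(
    json_file="subsegments.json",
    feat_type='melspectrogram',
    n_mels=80)

record = dataset[0]  # dict with the meta_info fields plus 'feat'
print(record['id'], record['record_id'], record['feat'].shape)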