add vector csv dataset format, test=doc

pull/1630/head
xiongxinlei 3 years ago
parent 5b05300e53
commit 30b5b3cb9e

@@ -4,9 +4,9 @@
 # we should explicitly specify the wav path of vox2 audio data converted from m4a
 vox2_base_path:
 augment: True
-batch_size: 16
+batch_size: 32
 num_workers: 2
-num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
 shuffle: True
 skip_prep: False
 split_ratio: 0.9

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-stage=7
+stage=0
 stop_stage=100

 . ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
@@ -32,7 +32,7 @@ mkdir -p ${dir}
 # Generally the `MAIN_ROOT` refers to the root of PaddleSpeech,
 # which is defined in the path.sh
-# And we will download the
+# And we will download the voxceleb data and rirs noise to ${MAIN_ROOT}/dataset

 TARGET_DIR=${MAIN_ROOT}/dataset
 mkdir -p ${TARGET_DIR}
@@ -98,7 +98,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # generate the vox csv file
     # Currently, our training system use csv file for dataset
     echo "convert the json format to csv format to be compatible with training process"
-    python3 local/make_csv_dataset_from_json.py\
+    python3 local/make_vox_csv_dataset_from_json.py\
         --train "${dir}/vox1/manifest.dev" \
         --test "${dir}/vox1/manifest.test" \
         --target_dir "${dir}/vox/" \

@@ -20,31 +20,29 @@ import csv
 import os
 from typing import List

-import paddle
 import tqdm
 from yacs.config import CfgNode

-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.training.seeding import seed_everything
-
-logger = Log(__name__).getlog()
-
 from paddleaudio import load as load_audio
-from paddleaudio import save as save_wav
-
-
-def get_chunks(seg_dur, audio_id, audio_duration):
-    num_chunks = int(audio_duration / seg_dur)  # all in milliseconds
-    chunk_lst = [
-        audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
-        for i in range(num_chunks)
-    ]
-    return chunk_lst
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.utils.utils import get_chunks

+logger = Log(__name__).getlog()

-def get_audio_info(wav_file: str,
-                   split_chunks: bool,
-                   base_path: str,
-                   chunk_duration: float=3.0) -> List[List[str]]:
+
+def get_chunks_list(wav_file: str,
+                    split_chunks: bool,
+                    base_path: str,
+                    chunk_duration: float=3.0) -> List[List[str]]:
+    """Get the single audio file info
+
+    Args:
+        wav_file (list): the wav audio file and get this audio segment info list
+        split_chunks (bool): audio split flag
+        base_path (str): the audio base path
+        chunk_duration (float): the chunk duration.
+                                if set the split_chunks, we split the audio into multi-chunks segment.
+    """
     waveform, sr = load_audio(wav_file)
     audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0]
     audio_duration = waveform.shape[0] / sr
@@ -57,13 +55,16 @@ def get_audio_info(wav_file: str,
             s, e = chunk.split("_")[-2:]  # Timestamps of start and end
             start_sample = int(float(s) * sr)
             end_sample = int(float(e) * sr)
-            new_wav_file = os.path.join(base_path,
-                                        audio_id + f'_chunk_{idx+1:02}.wav')
-            save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
-            # id, duration, new_wav
-            ret.append([chunk, chunk_duration, new_wav_file])
+
+            # currently, all vector csv data format use one representation
+            # id, duration, wav, start, stop, spk_id
+            ret.append([
+                chunk, audio_duration, wav_file, start_sample, end_sample,
+                "noise"
+            ])
     else:  # Keep whole audio.
-        ret.append([audio_id, audio_duration, wav_file])
+        ret.append(
+            [audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"])
     return ret
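Note on the chunk bookkeeping above: each chunk id encodes its start/end offsets in seconds, and the csv row now stores sample indices into the original wav instead of writing a new chunk file to disk. A minimal sketch of that conversion; the file name, 9 s duration, and 16 kHz rate are illustrative assumptions:

sr = 16000                        # assumed sample rate, for illustration
chunk = "noise1_3.0_6.0"          # e.g. from get_chunks(3.0, "noise1", 9.0)
s, e = chunk.split("_")[-2:]      # "3.0", "6.0" -> offsets in seconds
start_sample = int(float(s) * sr)     # 48000
end_sample = int(float(e) * sr)       # 96000
# one csv row: id, duration, wav, start, stop, spk_id
row = [chunk, 9.0, "rir_noise/noise1.wav", start_sample, end_sample, "noise"]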
@@ -71,12 +72,20 @@ def generate_csv(wav_files,
                  output_file: str,
                  base_path: str,
                  split_chunks: bool=True):
-    print(f'Generating csv: {output_file}')
-    header = ["id", "duration", "wav"]
+    """Prepare the csv file according the wav files
+
+    Args:
+        wav_files (list): all the audio list to prepare the csv file
+        output_file (str): the output csv file
+        config (CfgNode): yaml configuration content
+        split_chunks (bool): audio split flag
+    """
+    logger.info(f'Generating csv: {output_file}')
+    header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
     csv_lines = []
     for item in tqdm.tqdm(wav_files):
         csv_lines.extend(
-            get_audio_info(
+            get_chunks_list(
                 item, base_path=base_path, split_chunks=split_chunks))

     if not os.path.exists(os.path.dirname(output_file)):
@@ -91,11 +100,12 @@ def generate_csv(wav_files,

 def prepare_data(args, config):
-    # stage0: set the cpu device,
-    # all data prepare process will be done in cpu mode
-    paddle.device.set_device("cpu")
-    # set the random seed, it is a must for multiprocess training
-    seed_everything(config.seed)
+    """Convert the jsonline format to csv format
+
+    Args:
+        args (argparse.Namespace): scripts args
+        config (CfgNode): yaml configuration content
+    """
     # if external config set the skip_prep flat, we will do nothing
     if config.skip_prep:
         return
@@ -119,6 +129,7 @@ def prepare_data(args, config):
             noise_files.append(os.path.join(base_path, noise_file))

     csv_path = os.path.join(args.data_dir, 'csv')
+    logger.info(f"csv path: {csv_path}")
     generate_csv(
         rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
     generate_csv(

@@ -21,51 +21,34 @@ import json
 import os
 import random

-import paddle
 import tqdm
 from yacs.config import CfgNode

 from paddleaudio import load as load_audio
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.training.seeding import seed_everything
+from paddlespeech.vector.utils.utils import get_chunks

 logger = Log(__name__).getlog()


-def get_chunks(seg_dur, audio_id, audio_duration):
-    """Get all chunk segments from a utterance
-
-    Args:
-        seg_dur (float): segment chunk duration
-        audio_id (str): utterance name
-        audio_duration (float): utterance duration
-
-    Returns:
-        List: all the chunk segments
-    """
-    num_chunks = int(audio_duration / seg_dur)  # all in milliseconds
-    chunk_lst = [
-        audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
-        for i in range(num_chunks)
-    ]
-    return chunk_lst
-
-
 def prepare_csv(wav_files, output_file, config, split_chunks=True):
     """Prepare the csv file according the wav files

     Args:
-        dataset_list (list): all the dataset to get the test utterances
-        verification_file (str): voxceleb1 trial file
+        wav_files (list): all the audio list to prepare the csv file
+        output_file (str): the output csv file
+        config (CfgNode): yaml configuration content
+        split_chunks (bool): audio split flag
     """
     if not os.path.exists(os.path.dirname(output_file)):
         os.makedirs(os.path.dirname(output_file))
     csv_lines = []
-    header = ["id", "duration", "wav", "start", "stop", "spk_id"]
+    header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
     # voxceleb meta info for each training utterance segment
     # we extract a segment from a utterance to train
     # and the segment' period is between start and stop time point in the original wav file
     # each field in the meta means as follows:
-    # id: the utterance segment name
+    # utt_id: the utterance segment name
     # duration: utterance segment time
     # wav: utterance file path
     # start: start point in the original wav file
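# Putting the field comments above together, a sketch of one data row in the
# generated csv; the utterance name, path and numbers below are invented for
# illustration, real rows are produced by prepare_csv:

import csv

header = ["utt_id", "duration", "wav", "start", "stop", "lab_id"]
# a hypothetical 3 s training segment cut from a voxceleb utterance
row = [
    "id10001-1zcIwhmdeo4-00001_0.0_3.0",            # utterance segment name
    3.0,                                            # segment duration (seconds)
    "/data/vox1/wav/id10001/1zcIwhmdeo4/00001.wav", # original wav path
    0,                                              # start sample in the wav
    48000,                                          # stop sample (3.0 s * 16 kHz)
    "id10001",                                      # speaker label id
]
with open("train.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerow(row)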
@@ -194,11 +177,9 @@ def prepare_data(args, config):
         args (argparse.Namespace): scripts args
         config (CfgNode): yaml configuration content
     """
-    # stage0: set the cpu device,
-    # all data prepare process will be done in cpu mode
-    paddle.device.set_device("cpu")
-    # set the random seed, it is a must for multiprocess training
-    seed_everything(config.seed)
+    # stage0: set the random seed
+    random.seed(config.seed)
+
     # if external config set the skip_prep flat, we will do nothing
     if config.skip_prep:
         return

@@ -29,7 +29,7 @@ from paddlespeech.vector.io.augment import waveform_augment
 from paddlespeech.vector.io.batch import batch_pad_right
 from paddlespeech.vector.io.batch import feature_normalize
 from paddlespeech.vector.io.batch import waveform_collate_fn
-from paddlespeech.vector.io.dataset import VoxCelebDataset
+from paddlespeech.vector.io.dataset import CSVDataset
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.loss import AdditiveAngularMargin
 from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
@@ -55,11 +55,11 @@ def main(args, config):
     # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
     # note: some cmd must do in rank==0, so wo will refactor the data prepare code
-    train_dataset = VoxCelebDataset(
+    train_dataset = CSVDataset(
         csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
         spk_id2label_path=os.path.join(args.data_dir,
                                        "vox/meta/spk_id2label.txt"))
-    dev_dataset = VoxCelebDataset(
+    dev_dataset = CSVDataset(
         csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
         spk_id2label_path=os.path.join(args.data_dir,
                                        "vox/meta/spk_id2label.txt"))
@@ -74,7 +74,7 @@ def main(args, config):

     # stage4: build the speaker verification train instance with backbone model
     model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
+        backbone=ecapa_tdnn, num_class=config.num_speakers)

     # stage5: build the optimizer, we now only construct the AdamW optimizer
     # 140000 is single gpu steps
@@ -148,6 +148,7 @@ def main(args, config):
         train_reader_cost = 0.0
         train_feat_cost = 0.0
         train_run_cost = 0.0
+        train_misce_cost = 0.0

         reader_start = time.time()
         for batch_idx, batch in enumerate(train_loader):
@@ -203,12 +204,14 @@ def main(args, config):
             train_run_cost += time.time() - train_start

             # stage 9-8: Calculate average loss per batch
-            avg_loss += loss.numpy()[0]
+            train_misce_start = time.time()
+            avg_loss = loss.item()

             # stage 9-9: Calculate metrics, which is one-best accuracy
             preds = paddle.argmax(logits, axis=1)
             num_corrects += (preds == labels).numpy().sum()
             num_samples += feats.shape[0]
             timer.count()  # step plus one in timer

             # stage 9-10: print the log information only on 0-rank per log-freq batchs
@@ -227,6 +230,7 @@ def main(args, config):
                         train_feat_cost / config.log_interval)
                     print_msg += ' avg_train_cost: {:.5f} sec,'.format(
                         train_run_cost / config.log_interval)
                     print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
                         lr, timer.timing, timer.eta)
                     logger.info(print_msg)

@@ -14,6 +14,7 @@
 # this is modified from SpeechBrain
 # https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
 import math
+import os
 from typing import List

 import numpy as np
@@ -22,13 +23,12 @@ import paddle.nn as nn
 import paddle.nn.functional as F

 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.io.dataset import RIRSNoiseDataset
+from paddlespeech.vector.io.dataset import CSVDataset
 from paddlespeech.vector.io.signal_processing import compute_amplitude
 from paddlespeech.vector.io.signal_processing import convolve1d
 from paddlespeech.vector.io.signal_processing import dB_to_amplitude
 from paddlespeech.vector.io.signal_processing import notch_filter
 from paddlespeech.vector.io.signal_processing import reverberate
-# from paddleaudio.datasets.rirs_noises import OpenRIRNoise

 logger = Log(__name__).getlog()
@@ -510,7 +510,7 @@ class AddNoise(nn.Layer):
             assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
             return np.pad(x, [0, w], mode=mode, **kwargs)

-        ids = [item['id'] for item in batch]
+        ids = [item['utt_id'] for item in batch]
         lengths = np.asarray([item['feat'].shape[0] for item in batch])
         waveforms = list(
             map(lambda x: pad(x, max(max_length, lengths.max().item())),
@@ -590,7 +590,7 @@ class AddReverb(nn.Layer):
             assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
             return np.pad(x, [0, w], mode=mode, **kwargs)

-        ids = [item['id'] for item in batch]
+        ids = [item['utt_id'] for item in batch]
         lengths = np.asarray([item['feat'].shape[0] for item in batch])
         waveforms = list(
             map(lambda x: pad(x, lengths.max().item()),
@@ -840,10 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
         List[paddle.nn.Layer]: all augment process
     """
     logger.info("start to build the augment pipeline")
-    noise_dataset = RIRSNoiseDataset(csv_path=os.path.join(
-        target_dir, "rir_noise/csv/noise.csv"))
-    rir_dataset = OpenRIRNoise(csv_path=os.path.join(target_dir,
-                                                     "rir_noise/csv/rir.csv"))
+    noise_dataset = CSVDataset(csv_path=os.path.join(target_dir,
+                                                     "rir_noise/csv/noise.csv"))
+    rir_dataset = CSVDataset(csv_path=os.path.join(target_dir,
+                                                   "rir_noise/csv/rir.csv"))

     wavedrop = TimeDomainSpecAugment(
         sample_rate=16000,

@@ -11,18 +11,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
+from dataclasses import dataclass
+from dataclasses import fields

 from paddle.io import Dataset

 from paddleaudio import load as load_audio
+from paddlespeech.s2t.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+# the audio meta info in the vector CSVDataset
+# utt_id: the utterance segment name
+# duration: utterance segment time
+# wav: utterance file path
+# start: start point in the original wav file
+# stop: stop point in the original wav file
+# lab_id: the utterance segment's label id
+@dataclass
+class meta_info:
+    utt_id: str
+    duration: float
+    wav: str
+    start: int
+    stop: int
+    lab_id: str


-class VoxCelebDataset(Dataset):
-    meta_info = collections.namedtuple(
-        'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
-
-    def __init__(self, csv_path, spk_id2label_path, config):
+class CSVDataset(Dataset):
+    # meta_info = collections.namedtuple(
+    #     'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
+    def __init__(self, csv_path, spk_id2label_path=None, config=None):
         super().__init__()
         self.csv_path = csv_path
         self.spk_id2label_path = spk_id2label_path
@@ -32,34 +52,41 @@ class VoxCelebDataset(Dataset):
     def load_data_csv(self):
         data = []

         with open(self.csv_path, 'r') as rf:
             for line in rf.readlines()[1:]:
                 audio_id, duration, wav, start, stop, spk_id = line.strip(
                 ).split(',')
                 data.append(
-                    self.meta_info(audio_id,
-                                   float(duration), wav,
-                                   int(start), int(stop), spk_id))
+                    meta_info(audio_id,
+                              float(duration), wav,
+                              int(start), int(stop), spk_id))
         return data

     def load_speaker_to_label(self):
+        if not self.spk_id2label_path:
+            logger.warning("No speaker id to label file")
+            return
+
+        spk_id2label = {}
         with open(self.spk_id2label_path, 'r') as f:
             for line in f.readlines():
                 spk_id, label = line.strip().split(' ')
-                self.spk_id2label[spk_id] = int(label)
+                spk_id2label[spk_id] = int(label)
+
+        return spk_id2label

     def convert_to_record(self, idx: int):
         sample = self.data[idx]

         record = {}
         # To show all fields in a namedtuple: `type(sample)._fields`
-        for field in type(sample)._fields:
-            record[field] = getattr(sample, field)
+        for field in fields(sample):
+            record[field.name] = getattr(sample, field.name)
         waveform, sr = load_audio(record['wav'])

         # random select a chunk audio samples from the audio
-        if self.config.random_chunk:
+        if self.config and self.config.random_chunk:
             num_wav_samples = waveform.shape[0]
             num_chunk_samples = int(self.config.chunk_duration * sr)
             start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
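load_speaker_to_label now returns the mapping (or None) instead of mutating state in place; it expects one space-separated "spk_id label" pair per line. A sketch of the expected file content and an equivalent one-liner; the speaker ids are illustrative:

# spk_id2label.txt (illustrative):
#   id10001 0
#   id10002 1
#   id10003 2
with open("spk_id2label.txt") as f:
    spk_id2label = {spk: int(lab)
                    for spk, lab in (line.strip().split(' ') for line in f)}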
@@ -71,46 +98,9 @@ class VoxCelebDataset(Dataset):
             # we only return the waveform as feat
             waveform = waveform[start:stop]
         record.update({'feat': waveform})
-        record.update({'label': self.spk_id2label[record['spk_id']]})
-        return record
-
-    def __getitem__(self, idx):
-        return self.convert_to_record(idx)
-
-    def __len__(self):
-        return len(self.data)
-
-
-class RIRSNoiseDataset(Dataset):
-    meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
-
-    def __init__(self, csv_path):
-        super().__init__()
-        self.csv_path = csv_path
-        self.data = self.load_csv_data()
-
-    def load_csv_data(self):
-        data = []
-        with open(self.csv_path, 'r') as rf:
-            for line in rf.readlines()[1:]:
-                audio_id, duration, wav = line.strip().split(',')
-                data.append(self.meta_info(audio_id, float(duration), wav))
-
-        random.shuffle(data)
-        return data
-
-    def convert_to_record(self, idx: int):
-        sample = self.data[idx]
-        record = {}
-        # To show all fields in a namedtuple: `type(sample)._fields`
-        for field in type(sample)._fields:
-            record[field] = getattr(sample, field)
-        waveform, sr = load_audio(record['wav'])
-        record.update({'feat': waveform})
+        if self.spk_id2label:
+            record.update({'label': self.spk_id2label[record['lab_id']]})
+
         return record

     def __getitem__(self, idx):

@@ -0,0 +1,32 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def get_chunks(seg_dur, audio_id, audio_duration):
+    """Get all chunk segments from a utterance
+
+    Args:
+        seg_dur (float): segment chunk duration
+        audio_id (str): utterance name
+        audio_duration (float): utterance duration
+
+    Returns:
+        List: all the chunk segments
+    """
+    num_chunks = int(audio_duration / seg_dur)  # all in milliseconds
+    chunk_lst = [
+        audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
+        for i in range(num_chunks)
+    ]
+    return chunk_lst
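A quick sketch of what the relocated get_chunks produces; note that the recipes pass seg_dur and audio_duration in seconds (the "# all in milliseconds" comment appears stale), and any remainder shorter than seg_dur is dropped by the int() truncation:

from paddlespeech.vector.utils.utils import get_chunks

# a hypothetical 10 s utterance split into 3 s chunks yields 3 whole chunks
chunks = get_chunks(seg_dur=3.0, audio_id="noise1", audio_duration=10.0)
# -> ['noise1_0.0_3.0', 'noise1_3.0_6.0', 'noise1_6.0_9.0']
# the trailing 1.0 s of audio is not covered by any chunk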