You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
202 lines
7.3 KiB
202 lines
7.3 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import collections
|
|
import csv
|
|
import os
|
|
import random
|
|
from typing import List
|
|
|
|
from paddle.io import Dataset
|
|
from tqdm import tqdm
|
|
|
|
from ..backends.soundfile_backend import soundfile_load as load_audio
|
|
from ..backends.soundfile_backend import soundfile_save as save_wav
|
|
from ..utils import DATA_HOME
|
|
from ..utils.download import download_and_decompress
|
|
from .dataset import feat_funcs
|
|
|
|
# Names exported by `from ... import *`: only the dataset class.
__all__ = [
    "OpenRIRNoise",
]
|
|
|
|
|
|
class OpenRIRNoise(Dataset):
    """OpenSLR resource 28: room impulse responses and point-source noises.

    On construction the archive is downloaded into ``base_path`` (if its
    extracted ``RIRS_NOISES`` folder is missing) and a CSV index with the
    columns ``id,duration,wav`` is generated for both subsets (if the CSV for
    the requested subset is missing). Audio longer than ``chunk_duration``
    seconds is cut into full-length chunks that are written as new wav files
    under ``base_path``; the trailing partial chunk is dropped.

    Args:
        subset: One of ``subsets`` (``'rir'`` or ``'noise'``).
        feat_type: Key into ``feat_funcs`` selecting the feature function
            applied in ``__getitem__``; when the selected function is falsy
            the raw waveform is returned as ``feat``.
        target_dir: If given, CSV files are read from / written to
            ``<target_dir>/open_rir_noise/csv`` instead of ``base_path/csv``.
            Audio is always downloaded to ``base_path``.
        random_chunk: Stored on the instance; not used by this class.
        chunk_duration: Chunk length in seconds used when generating CSVs.
        seed: Accepted for API compatibility but unused; the items are
            shuffled with the global ``random`` state.
        **kwargs: Forwarded as keyword arguments to the feature function.
    """

    # (sic) attribute name kept as-is for backward compatibility.
    archieves = [
        {
            'url': 'http://www.openslr.org/resources/28/rirs_noises.zip',
            'md5': 'e6f48e257286e05de56413b4779d8ffb',
        },
    ]

    sample_rate = 16000  # not referenced by this class
    meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
    base_path = os.path.join(DATA_HOME, 'open_rir_noise')
    wav_path = os.path.join(base_path, 'RIRS_NOISES')
    # Default CSV directory; an instance created with `target_dir` shadows it.
    csv_path = os.path.join(base_path, 'csv')
    subsets = ['rir', 'noise']

    def __init__(self,
                 subset: str='rir',
                 feat_type: str='raw',
                 target_dir=None,
                 random_chunk: bool=True,
                 chunk_duration: float=3.0,
                 seed: int=0,
                 **kwargs):
        assert subset in self.subsets, \
            'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
        super(OpenRIRNoise, self).__init__()

        self.subset = subset
        self.feat_type = feat_type
        self.feat_config = kwargs
        self.random_chunk = random_chunk
        self.chunk_duration = chunk_duration

        # Bind the CSV directory to this instance only. Assigning it on the
        # class would leak this target_dir into every instance created later.
        if target_dir:
            self.csv_path = os.path.join(target_dir, "open_rir_noise", "csv")

        # NOTE: `seed` is unused; the global `random.seed(seed)` call is
        # disabled, so the shuffle in `_get_data` is not reproducible.
        self._data = self._get_data()

    def _get_data(self):
        """Ensure audio and CSV index exist, then load this subset's rows.

        Returns:
            A shuffled list of ``meta_info(id, duration, wav)`` tuples.
        """
        # Download audio files. Test for the extracted folder rather than
        # base_path alone: base_path also hosts the default csv directory and
        # can exist without any audio in it.
        print(f"rirs noises base path: {self.base_path}")
        if not os.path.isdir(self.wav_path):
            download_and_decompress(
                self.archieves, self.base_path, decompress=True)
        else:
            print(
                f"{self.base_path} already exists, we will not download and decompress again"
            )

        # Data preparation. Test for the subset's CSV file itself so a run
        # that created the directory but not the file is redone.
        csv_file = os.path.join(self.csv_path, f'{self.subset}.csv')
        print(f"prepare the csv to {self.csv_path}")
        if not os.path.isfile(csv_file):
            os.makedirs(self.csv_path, exist_ok=True)
            self.prepare_data()

        data = []
        # Read with the csv module so quoted fields written by csv.writer
        # round-trip; newline='' as the csv module requires.
        with open(csv_file, 'r', newline='') as rf:
            reader = csv.reader(rf)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:  # tolerate blank lines (e.g. "\r\r\n" endings)
                    continue
                audio_id, duration, wav = row
                data.append(self.meta_info(audio_id, float(duration), wav))

        random.shuffle(data)
        return data

    def _convert_to_record(self, idx: int):
        """Return a dict of the ``meta_info`` fields of item ``idx`` plus
        ``feat``: the computed features, or the waveform when the selected
        feature function is falsy."""
        # Validate before doing any audio I/O.
        assert self.feat_type in feat_funcs.keys(), \
            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"

        sample = self._data[idx]
        record = dict(sample._asdict())

        waveform, sr = load_audio(record['wav'])

        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})
        return record

    @staticmethod
    def _get_chunks(seg_dur, audio_id, audio_duration):
        """Return ids ``<audio_id>_<start>_<end>`` for each full chunk.

        ``seg_dur`` and ``audio_duration`` are in seconds; a trailing partial
        chunk is dropped.
        """
        num_chunks = int(audio_duration / seg_dur)  # all in seconds
        return [
            audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
            for i in range(num_chunks)
        ]

    def _get_audio_info(self, wav_file: str,
                        split_chunks: bool) -> List[list]:
        """Return CSV rows ``[id, duration, wav]`` for one audio file.

        If ``split_chunks`` is true and the file is longer than
        ``self.chunk_duration`` seconds, each full chunk is saved as a new wav
        under ``self.base_path`` and gets its own row. Otherwise one row for
        the whole file is returned.
        """
        waveform, sr = load_audio(wav_file)
        # Id = path relative to base_path without its extension. relpath and
        # splitext stay correct with Windows separators and with dots in
        # names, unlike splitting on "/open_rir_noise/" and the first ".".
        audio_id = os.path.splitext(
            os.path.relpath(wav_file, self.base_path))[0]
        audio_duration = waveform.shape[0] / sr

        ret = []
        if split_chunks and audio_duration > self.chunk_duration:
            # Split into pieces of self.chunk_duration seconds.
            uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
                                                audio_duration)
            for idx, chunk in enumerate(uniq_chunks_list):
                s, e = chunk.split("_")[-2:]  # start / end timestamps (sec)
                start_sample = int(float(s) * sr)
                end_sample = int(float(e) * sr)
                new_wav_file = os.path.join(self.base_path,
                                            audio_id + f'_chunk_{idx+1:02}.wav')
                save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
                # id, duration, new_wav
                ret.append([chunk, self.chunk_duration, new_wav_file])
        else:  # Keep whole audio.
            ret.append([audio_id, audio_duration, wav_file])
        return ret

    def generate_csv(self,
                     wav_files: List[str],
                     output_file: str,
                     split_chunks: bool=True):
        """Write ``output_file`` with header ``id,duration,wav`` followed by
        the rows produced by ``_get_audio_info`` for every file in
        ``wav_files``. All rows are computed before the file is opened."""
        print(f'Generating csv: {output_file}')
        header = ["id", "duration", "wav"]

        csv_lines = []
        rows_per_file = (self._get_audio_info(wav_file, split_chunks)
                         for wav_file in wav_files)
        for info in tqdm(rows_per_file, total=len(wav_files)):
            csv_lines.extend(info)

        # newline='' is required by the csv module; otherwise the writer's
        # "\r\n" terminator is translated to "\r\r\n" on Windows.
        with open(output_file, mode="w", newline='') as csv_f:
            csv_writer = csv.writer(
                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            csv_writer.writerow(header)
            csv_writer.writerows(csv_lines)

    def _read_file_list(self, list_file):
        """Return wav paths listed in ``list_file``, joined onto base_path.

        The path is taken as the last whitespace-separated token of each
        line; blank lines are skipped.
        """
        wav_files = []
        with open(list_file, 'r') as f:
            for line in f:
                fields = line.split()
                if fields:
                    wav_files.append(os.path.join(self.base_path, fields[-1]))
        return wav_files

    def prepare_data(self):
        """Generate ``rir.csv`` and ``noise.csv`` in ``self.csv_path`` from the
        list files shipped inside the extracted archive."""
        rir_files = self._read_file_list(
            os.path.join(self.wav_path, "real_rirs_isotropic_noises",
                         "rir_list"))
        noise_files = self._read_file_list(
            os.path.join(self.wav_path, "pointsource_noises", "noise_list"))

        self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
        self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))

    def __getitem__(self, idx):
        return self._convert_to_record(idx)

    def __len__(self):
        return len(self._data)
|