commit
a1c6ee5ca1
@ -0,0 +1,370 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['CRF']
|
||||
|
||||
|
||||
class CRF(nn.Layer):
|
||||
"""
|
||||
Linear-chain Conditional Random Field (CRF).
|
||||
|
||||
Args:
|
||||
nb_labels (int): number of labels in your tagset, including special symbols.
|
||||
bos_tag_id (int): integer representing the beginning of sentence symbol in
|
||||
your tagset.
|
||||
eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
|
||||
pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
|
||||
If None, the model will treat the PAD as a normal tag. Otherwise, the model
|
||||
will apply constraints for PAD transitions.
|
||||
batch_first (bool): Whether the first dimension represents the batch dimension.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
nb_labels: int,
|
||||
bos_tag_id: int,
|
||||
eos_tag_id: int,
|
||||
pad_tag_id: int=None,
|
||||
batch_first: bool=True):
|
||||
super().__init__()
|
||||
|
||||
self.nb_labels = nb_labels
|
||||
self.BOS_TAG_ID = bos_tag_id
|
||||
self.EOS_TAG_ID = eos_tag_id
|
||||
self.PAD_TAG_ID = pad_tag_id
|
||||
self.batch_first = batch_first
|
||||
|
||||
# initialize transitions from a random uniform distribution between -0.1 and 0.1
|
||||
self.transitions = self.create_parameter(
|
||||
[self.nb_labels, self.nb_labels],
|
||||
default_initializer=nn.initializer.Uniform(-0.1, 0.1))
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
# enforce contraints (rows=from, columns=to) with a big negative number
|
||||
# so exp(-10000) will tend to zero
|
||||
|
||||
# no transitions allowed to the beginning of sentence
|
||||
self.transitions[:, self.BOS_TAG_ID] = -10000.0
|
||||
# no transition alloed from the end of sentence
|
||||
self.transitions[self.EOS_TAG_ID, :] = -10000.0
|
||||
|
||||
if self.PAD_TAG_ID is not None:
|
||||
# no transitions from padding
|
||||
self.transitions[self.PAD_TAG_ID, :] = -10000.0
|
||||
# no transitions to padding
|
||||
self.transitions[:, self.PAD_TAG_ID] = -10000.0
|
||||
# except if the end of sentence is reached
|
||||
# or we are already in a pad position
|
||||
self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
|
||||
self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
|
||||
|
||||
def forward(self,
|
||||
emissions: paddle.Tensor,
|
||||
tags: paddle.Tensor,
|
||||
mask: paddle.Tensor=None) -> paddle.Tensor:
|
||||
"""Compute the negative log-likelihood. See `log_likelihood` method."""
|
||||
nll = -self.log_likelihood(emissions, tags, mask=mask)
|
||||
return nll
|
||||
|
||||
def log_likelihood(self, emissions, tags, mask=None):
|
||||
"""Compute the probability of a sequence of tags given a sequence of
|
||||
emissions scores.
|
||||
|
||||
Args:
|
||||
emissions (paddle.Tensor): Sequence of emissions for each label.
|
||||
Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
|
||||
(seq_len, batch_size, nb_labels) otherwise.
|
||||
tags (paddle.LongTensor): Sequence of labels.
|
||||
Shape of (batch_size, seq_len) if batch_first is True,
|
||||
(seq_len, batch_size) otherwise.
|
||||
mask (paddle.FloatTensor, optional): Tensor representing valid positions.
|
||||
If None, all positions are considered valid.
|
||||
Shape of (batch_size, seq_len) if batch_first is True,
|
||||
(seq_len, batch_size) otherwise.
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: sum of the log-likelihoods for each sequence in the batch.
|
||||
Shape of ()
|
||||
"""
|
||||
# fix tensors order by setting batch as the first dimension
|
||||
if not self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
tags = tags.transpose(0, 1)
|
||||
|
||||
if mask is None:
|
||||
mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)
|
||||
|
||||
scores = self._compute_scores(emissions, tags, mask=mask)
|
||||
partition = self._compute_log_partition(emissions, mask=mask)
|
||||
return paddle.sum(scores - partition)
|
||||
|
||||
def decode(self, emissions, mask=None):
|
||||
"""Find the most probable sequence of labels given the emissions using
|
||||
the Viterbi algorithm.
|
||||
|
||||
Args:
|
||||
emissions (paddle.Tensor): Sequence of emissions for each label.
|
||||
Shape (batch_size, seq_len, nb_labels) if batch_first is True,
|
||||
(seq_len, batch_size, nb_labels) otherwise.
|
||||
mask (paddle.FloatTensor, optional): Tensor representing valid positions.
|
||||
If None, all positions are considered valid.
|
||||
Shape (batch_size, seq_len) if batch_first is True,
|
||||
(seq_len, batch_size) otherwise.
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: the viterbi score for the for each batch.
|
||||
Shape of (batch_size,)
|
||||
list of lists: the best viterbi sequence of labels for each batch. [B, T]
|
||||
"""
|
||||
# fix tensors order by setting batch as the first dimension
|
||||
if not self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
tags = tags.transpose(0, 1)
|
||||
|
||||
if mask is None:
|
||||
mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)
|
||||
|
||||
scores, sequences = self._viterbi_decode(emissions, mask)
|
||||
return scores, sequences
|
||||
|
||||
def _compute_scores(self, emissions, tags, mask):
|
||||
"""Compute the scores for a given batch of emissions with their tags.
|
||||
|
||||
Args:
|
||||
emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
|
||||
tags (Paddle.LongTensor): (batch_size, seq_len)
|
||||
mask (Paddle.FloatTensor): (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: Scores for each batch.
|
||||
Shape of (batch_size,)
|
||||
"""
|
||||
batch_size, seq_length = tags.shape
|
||||
scores = paddle.zeros([batch_size])
|
||||
|
||||
# save first and last tags to be used later
|
||||
first_tags = tags[:, 0]
|
||||
last_valid_idx = mask.int().sum(1) - 1
|
||||
|
||||
# TODO(Hui Zhang): not support fancy index.
|
||||
# last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
|
||||
batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
|
||||
gather_last_valid_idx = paddle.stack(
|
||||
[batch_idx, last_valid_idx], axis=-1)
|
||||
last_tags = tags.gather_nd(gather_last_valid_idx)
|
||||
|
||||
# add the transition from BOS to the first tags for each batch
|
||||
# t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
|
||||
t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)
|
||||
|
||||
# add the [unary] emission scores for the first tags for each batch
|
||||
# for all batches, the first word, see the correspondent emissions
|
||||
# for the first tags (which is a list of ids):
|
||||
# emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
|
||||
# e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
|
||||
gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
|
||||
e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)
|
||||
|
||||
# the scores for a word is just the sum of both scores
|
||||
scores += e_scores + t_scores
|
||||
|
||||
# now lets do this for each remaining word
|
||||
for i in range(1, seq_length):
|
||||
|
||||
# we could: iterate over batches, check if we reached a mask symbol
|
||||
# and stop the iteration, but vecotrizing is faster due to gpu,
|
||||
# so instead we perform an element-wise multiplication
|
||||
is_valid = mask[:, i]
|
||||
|
||||
previous_tags = tags[:, i - 1]
|
||||
current_tags = tags[:, i]
|
||||
|
||||
# calculate emission and transition scores as we did before
|
||||
# e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
|
||||
gather_current_tags_idx = paddle.stack(
|
||||
[batch_idx, current_tags], axis=-1)
|
||||
e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
|
||||
# t_scores = self.transitions[previous_tags, current_tags]
|
||||
gather_transitions_idx = paddle.stack(
|
||||
[previous_tags, current_tags], axis=-1)
|
||||
t_scores = self.transitions.gather_nd(gather_transitions_idx)
|
||||
|
||||
# apply the mask
|
||||
e_scores = e_scores * is_valid
|
||||
t_scores = t_scores * is_valid
|
||||
|
||||
scores += e_scores + t_scores
|
||||
|
||||
# add the transition from the end tag to the EOS tag for each batch
|
||||
# scores += self.transitions[last_tags, self.EOS_TAG_ID]
|
||||
scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_log_partition(self, emissions, mask):
|
||||
"""Compute the partition function in log-space using the forward-algorithm.
|
||||
|
||||
Args:
|
||||
emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
|
||||
mask (Paddle.FloatTensor): (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: the partition scores for each batch.
|
||||
Shape of (batch_size,)
|
||||
"""
|
||||
batch_size, seq_length, nb_labels = emissions.shape
|
||||
|
||||
# in the first iteration, BOS will have all the scores
|
||||
alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
|
||||
0) + emissions[:, 0]
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# (bs, nb_labels) -> (bs, 1, nb_labels)
|
||||
e_scores = emissions[:, i].unsqueeze(1)
|
||||
|
||||
# (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
|
||||
t_scores = self.transitions.unsqueeze(0)
|
||||
|
||||
# (bs, nb_labels) -> (bs, nb_labels, 1)
|
||||
a_scores = alphas.unsqueeze(2)
|
||||
|
||||
scores = e_scores + t_scores + a_scores
|
||||
new_alphas = paddle.logsumexp(scores, axis=1)
|
||||
|
||||
# set alphas if the mask is valid, otherwise keep the current values
|
||||
is_valid = mask[:, i].unsqueeze(-1)
|
||||
alphas = is_valid * new_alphas + (1 - is_valid) * alphas
|
||||
|
||||
# add the scores for the final transition
|
||||
last_transition = self.transitions[:, self.EOS_TAG_ID]
|
||||
end_scores = alphas + last_transition.unsqueeze(0)
|
||||
|
||||
# return a *log* of sums of exps
|
||||
return paddle.logsumexp(end_scores, axis=1)
|
||||
|
||||
def _viterbi_decode(self, emissions, mask):
|
||||
"""Compute the viterbi algorithm to find the most probable sequence of labels
|
||||
given a sequence of emissions.
|
||||
|
||||
Args:
|
||||
emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
|
||||
mask (Paddle.FloatTensor): (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: the viterbi score for the for each batch.
|
||||
Shape of (batch_size,)
|
||||
list of lists of ints: the best viterbi sequence of labels for each batch
|
||||
"""
|
||||
batch_size, seq_length, nb_labels = emissions.shape
|
||||
|
||||
# in the first iteration, BOS will have all the scores and then, the max
|
||||
alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
|
||||
0) + emissions[:, 0]
|
||||
|
||||
backpointers = []
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# (bs, nb_labels) -> (bs, 1, nb_labels)
|
||||
e_scores = emissions[:, i].unsqueeze(1)
|
||||
|
||||
# (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
|
||||
t_scores = self.transitions.unsqueeze(0)
|
||||
|
||||
# (bs, nb_labels) -> (bs, nb_labels, 1)
|
||||
a_scores = alphas.unsqueeze(2)
|
||||
|
||||
# combine current scores with previous alphas
|
||||
scores = e_scores + t_scores + a_scores
|
||||
|
||||
# so far is exactly like the forward algorithm,
|
||||
# but now, instead of calculating the logsumexp,
|
||||
# we will find the highest score and the tag associated with it
|
||||
# max_scores, max_score_tags = paddle.max(scores, axis=1)
|
||||
max_scores = paddle.max(scores, axis=1)
|
||||
max_score_tags = paddle.argmax(scores, axis=1)
|
||||
|
||||
# set alphas if the mask is valid, otherwise keep the current values
|
||||
is_valid = mask[:, i].unsqueeze(-1)
|
||||
alphas = is_valid * max_scores + (1 - is_valid) * alphas
|
||||
|
||||
# add the max_score_tags for our list of backpointers
|
||||
# max_scores has shape (batch_size, nb_labels) so we transpose it to
|
||||
# be compatible with our previous loopy version of viterbi
|
||||
backpointers.append(max_score_tags.t())
|
||||
|
||||
# add the scores for the final transition
|
||||
last_transition = self.transitions[:, self.EOS_TAG_ID]
|
||||
end_scores = alphas + last_transition.unsqueeze(0)
|
||||
|
||||
# get the final most probable score and the final most probable tag
|
||||
# max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
|
||||
max_final_scores = paddle.max(end_scores, axis=1)
|
||||
max_final_tags = paddle.argmax(end_scores, axis=1)
|
||||
|
||||
# find the best sequence of labels for each sample in the batch
|
||||
best_sequences = []
|
||||
emission_lengths = mask.int().sum(axis=1)
|
||||
for i in range(batch_size):
|
||||
|
||||
# recover the original sentence length for the i-th sample in the batch
|
||||
sample_length = emission_lengths[i].item()
|
||||
|
||||
# recover the max tag for the last timestep
|
||||
sample_final_tag = max_final_tags[i].item()
|
||||
|
||||
# limit the backpointers until the last but one
|
||||
# since the last corresponds to the sample_final_tag
|
||||
sample_backpointers = backpointers[:sample_length - 1]
|
||||
|
||||
# follow the backpointers to build the sequence of labels
|
||||
sample_path = self._find_best_path(i, sample_final_tag,
|
||||
sample_backpointers)
|
||||
|
||||
# add this path to the list of best sequences
|
||||
best_sequences.append(sample_path)
|
||||
|
||||
return max_final_scores, best_sequences
|
||||
|
||||
def _find_best_path(self, sample_id, best_tag, backpointers):
|
||||
"""Auxiliary function to find the best path sequence for a specific sample.
|
||||
|
||||
Args:
|
||||
sample_id (int): sample index in the range [0, batch_size)
|
||||
best_tag (int): tag which maximizes the final score
|
||||
backpointers (list of lists of tensors): list of pointers with
|
||||
shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
|
||||
represents the length of the ith sample in the batch
|
||||
|
||||
Returns:
|
||||
list of ints: a list of tag indexes representing the bast path
|
||||
"""
|
||||
# add the final best_tag to our best path
|
||||
best_path = [best_tag]
|
||||
|
||||
# traverse the backpointers in backwards
|
||||
for backpointers_t in reversed(backpointers):
|
||||
|
||||
# recover the best_tag at this timestep
|
||||
best_tag = backpointers_t[best_tag][sample_id].item()
|
||||
|
||||
# append to the beginning of the list so we don't need to reverse it later
|
||||
best_path.insert(0, best_tag)
|
||||
|
||||
return best_path
|
@ -0,0 +1,114 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'char'
|
||||
spm_model_prefix: ''
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 32
|
||||
min_input_len: 0.5
|
||||
max_input_len: 20.0 # second
|
||||
min_output_len: 0.0
|
||||
max_output_len: 400.0
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
specgram_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 0
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
causal: true
|
||||
use_dynamic_chunk: true
|
||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
||||
use_dynamic_left_chunk: false
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 180
|
||||
accum_grad: 4
|
||||
global_grad_clip: 5.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1e-6
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 128
|
||||
error_rate_type: cer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
__version__ = "0.2.2"
|
@ -0,0 +1,490 @@
|
||||
"""
|
||||
Module containing functions cloned from librosa
|
||||
|
||||
To make sure nnAudio would not become broken when updating librosa
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
### ----------------Functions for generating kenral for Mel Spectrogram------------ ###
|
||||
# This code is equalvant to from librosa.filters import mel
|
||||
# By doing so, we can run nnAudio without installing librosa
|
||||
def fft2gammatonemx(sr=20000, n_fft=2048, n_bins=64, width=1.0, fmin=0.0,
|
||||
fmax=11025, maxlen=1024):
|
||||
"""
|
||||
# Ellis' description in MATLAB:
|
||||
# [wts,cfreqa] = fft2gammatonemx(nfft, sr, nfilts, width, minfreq, maxfreq, maxlen)
|
||||
# Generate a matrix of weights to combine FFT bins into
|
||||
# Gammatone bins. nfft defines the source FFT size at
|
||||
# sampling rate sr. Optional nfilts specifies the number of
|
||||
# output bands required (default 64), and width is the
|
||||
# constant width of each band in Bark (default 1).
|
||||
# minfreq, maxfreq specify range covered in Hz (100, sr/2).
|
||||
# While wts has nfft columns, the second half are all zero.
|
||||
# Hence, aud spectrum is
|
||||
# fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft));
|
||||
# maxlen truncates the rows to this many bins.
|
||||
# cfreqs returns the actual center frequencies of each
|
||||
# gammatone band in Hz.
|
||||
#
|
||||
# 2009/02/22 02:29:25 Dan Ellis dpwe@ee.columbia.edu based on rastamat/audspec.m
|
||||
# Sat May 27 15:37:50 2017 Maddie Cusimano, mcusi@mit.edu 27 May 2017: convert to python
|
||||
"""
|
||||
|
||||
wts = np.zeros([n_bins, n_fft], dtype=np.float32)
|
||||
|
||||
# after Slaney's MakeERBFilters
|
||||
EarQ = 9.26449;
|
||||
minBW = 24.7;
|
||||
order = 1;
|
||||
|
||||
nFr = np.array(range(n_bins)) + 1
|
||||
em = EarQ * minBW
|
||||
cfreqs = (fmax + em) * np.exp(nFr * (-np.log(fmax + em) + np.log(fmin + em)) / n_bins) - em
|
||||
cfreqs = cfreqs[::-1]
|
||||
|
||||
GTord = 4
|
||||
ucircArray = np.array(range(int(n_fft / 2 + 1)))
|
||||
ucirc = np.exp(1j * 2 * np.pi * ucircArray / n_fft);
|
||||
# justpoles = 0 :taking out the 'if' corresponding to this.
|
||||
|
||||
ERB = width * np.power(np.power(cfreqs / EarQ, order) + np.power(minBW, order), 1 / order);
|
||||
B = 1.019 * 2 * np.pi * ERB;
|
||||
r = np.exp(-B / sr)
|
||||
theta = 2 * np.pi * cfreqs / sr
|
||||
pole = r * np.exp(1j * theta)
|
||||
T = 1 / sr
|
||||
ebt = np.exp(B * T);
|
||||
cpt = 2 * cfreqs * np.pi * T;
|
||||
ccpt = 2 * T * np.cos(cpt);
|
||||
scpt = 2 * T * np.sin(cpt);
|
||||
A11 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2);
|
||||
A12 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 + 2 ** 1.5) * scpt, ebt), 2);
|
||||
A13 = -np.divide(np.divide(ccpt, ebt) + np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2);
|
||||
A14 = -np.divide(np.divide(ccpt, ebt) - np.divide(np.sqrt(3 - 2 ** 1.5) * scpt, ebt), 2);
|
||||
zros = -np.array([A11, A12, A13, A14]) / T;
|
||||
wIdx = range(int(n_fft / 2 + 1))
|
||||
gain = np.abs((-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
|
||||
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
|
||||
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 - 2 ** (3 / 2)) * np.sin(
|
||||
2 * cfreqs * np.pi * T))) * (-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
|
||||
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (np.cos(2 * cfreqs * np.pi * T) + np.sqrt(
|
||||
3 - 2 ** (3 / 2)) * np.sin(2 * cfreqs * np.pi * T))) * (
|
||||
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
|
||||
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
|
||||
np.cos(2 * cfreqs * np.pi * T) - np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
|
||||
2 * cfreqs * np.pi * T))) * (
|
||||
-2 * np.exp(4 * 1j * cfreqs * np.pi * T) * T + 2 * np.exp(
|
||||
-(B * T) + 2 * 1j * cfreqs * np.pi * T) * T * (
|
||||
np.cos(2 * cfreqs * np.pi * T) + np.sqrt(3 + 2 ** (3 / 2)) * np.sin(
|
||||
2 * cfreqs * np.pi * T))) / (
|
||||
-2 / np.exp(2 * B * T) - 2 * np.exp(4 * 1j * cfreqs * np.pi * T) + 2 * (
|
||||
1 + np.exp(4 * 1j * cfreqs * np.pi * T)) / np.exp(B * T)) ** 4);
|
||||
# in MATLAB, there used to be 64 where here it says n_bins:
|
||||
wts[:, wIdx] = ((T ** 4) / np.reshape(gain, (n_bins, 1))) * np.abs(
|
||||
ucirc - np.reshape(zros[0], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[1], (n_bins, 1))) * np.abs(
|
||||
ucirc - np.reshape(zros[2], (n_bins, 1))) * np.abs(ucirc - np.reshape(zros[3], (n_bins, 1))) * (np.abs(
|
||||
np.power(np.multiply(np.reshape(pole, (n_bins, 1)) - ucirc, np.conj(np.reshape(pole, (n_bins, 1))) - ucirc),
|
||||
-GTord)));
|
||||
wts = wts[:, range(maxlen)];
|
||||
|
||||
return wts, cfreqs
|
||||
|
||||
def gammatone(sr, n_fft, n_bins=64, fmin=20.0, fmax=None, htk=False,
|
||||
norm=1, dtype=np.float32):
|
||||
"""Create a Filterbank matrix to combine FFT bins into Gammatone bins
|
||||
Parameters
|
||||
----------
|
||||
sr : number > 0 [scalar]
|
||||
sampling rate of the incoming signal
|
||||
n_fft : int > 0 [scalar]
|
||||
number of FFT components
|
||||
n_bins : int > 0 [scalar]
|
||||
number of Mel bands to generate
|
||||
fmin : float >= 0 [scalar]
|
||||
lowest frequency (in Hz)
|
||||
fmax : float >= 0 [scalar]
|
||||
highest frequency (in Hz).
|
||||
If `None`, use `fmax = sr / 2.0`
|
||||
htk : bool [scalar]
|
||||
use HTK formula instead of Slaney
|
||||
norm : {None, 1, np.inf} [scalar]
|
||||
if 1, divide the triangular mel weights by the width of the mel band
|
||||
(area normalization). Otherwise, leave all the triangles aiming for
|
||||
a peak value of 1.0
|
||||
dtype : np.dtype
|
||||
The data type of the output basis.
|
||||
By default, uses 32-bit (single-precision) floating point.
|
||||
Returns
|
||||
-------
|
||||
G : np.ndarray [shape=(n_bins, 1 + n_fft/2)]
|
||||
Gammatone transform matrix
|
||||
"""
|
||||
|
||||
if fmax is None:
|
||||
fmax = float(sr) / 2
|
||||
n_bins = int(n_bins)
|
||||
|
||||
weights,_ = fft2gammatonemx(sr=sr, n_fft=n_fft, n_bins=n_bins, fmin=fmin, fmax=fmax, maxlen=int(n_fft//2+1))
|
||||
|
||||
return (1/n_fft)*weights
|
||||
|
||||
def mel_to_hz(mels, htk=False):
|
||||
"""Convert mel bin numbers to frequencies
|
||||
Examples
|
||||
--------
|
||||
>>> librosa.mel_to_hz(3)
|
||||
200.
|
||||
>>> librosa.mel_to_hz([1,2,3,4,5])
|
||||
array([ 66.667, 133.333, 200. , 266.667, 333.333])
|
||||
Parameters
|
||||
----------
|
||||
mels : np.ndarray [shape=(n,)], float
|
||||
mel bins to convert
|
||||
htk : bool
|
||||
use HTK formula instead of Slaney
|
||||
Returns
|
||||
-------
|
||||
frequencies : np.ndarray [shape=(n,)]
|
||||
input mels in Hz
|
||||
See Also
|
||||
--------
|
||||
hz_to_mel
|
||||
"""
|
||||
|
||||
mels = np.asanyarray(mels)
|
||||
|
||||
if htk:
|
||||
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
|
||||
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
if mels.ndim:
|
||||
# If we have vector data, vectorize
|
||||
log_t = (mels >= min_log_mel)
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
||||
elif mels >= min_log_mel:
|
||||
# If we have scalar data, check directly
|
||||
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
|
||||
|
||||
return freqs
|
||||
|
||||
def hz_to_mel(frequencies, htk=False):
|
||||
"""Convert Hz to Mels
|
||||
Examples
|
||||
--------
|
||||
>>> librosa.hz_to_mel(60)
|
||||
0.9
|
||||
>>> librosa.hz_to_mel([110, 220, 440])
|
||||
array([ 1.65, 3.3 , 6.6 ])
|
||||
Parameters
|
||||
----------
|
||||
frequencies : number or np.ndarray [shape=(n,)] , float
|
||||
scalar or array of frequencies
|
||||
htk : bool
|
||||
use HTK formula instead of Slaney
|
||||
Returns
|
||||
-------
|
||||
mels : number or np.ndarray [shape=(n,)]
|
||||
input frequencies in Mels
|
||||
See Also
|
||||
--------
|
||||
mel_to_hz
|
||||
"""
|
||||
|
||||
frequencies = np.asanyarray(frequencies)
|
||||
|
||||
if htk:
|
||||
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
|
||||
|
||||
# Fill in the linear part
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
|
||||
mels = (frequencies - f_min) / f_sp
|
||||
|
||||
# Fill in the log-scale part
|
||||
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
if frequencies.ndim:
|
||||
# If we have array data, vectorize
|
||||
log_t = (frequencies >= min_log_hz)
|
||||
mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep
|
||||
elif frequencies >= min_log_hz:
|
||||
# If we have scalar data, heck directly
|
||||
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
|
||||
|
||||
return mels
|
||||
|
||||
def fft_frequencies(sr=22050, n_fft=2048):
|
||||
'''Alternative implementation of `np.fft.fftfreq`
|
||||
Parameters
|
||||
----------
|
||||
sr : number > 0 [scalar]
|
||||
Audio sampling rate
|
||||
n_fft : int > 0 [scalar]
|
||||
FFT window size
|
||||
Returns
|
||||
-------
|
||||
freqs : np.ndarray [shape=(1 + n_fft/2,)]
|
||||
Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)`
|
||||
Examples
|
||||
--------
|
||||
>>> librosa.fft_frequencies(sr=22050, n_fft=16)
|
||||
array([ 0. , 1378.125, 2756.25 , 4134.375,
|
||||
5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ])
|
||||
'''
|
||||
|
||||
return np.linspace(0,
|
||||
float(sr) / 2,
|
||||
int(1 + n_fft//2),
|
||||
endpoint=True)
|
||||
|
||||
def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False):
|
||||
"""
|
||||
This function is cloned from librosa 0.7.
|
||||
Please refer to the original
|
||||
`documentation <https://librosa.org/doc/latest/generated/librosa.mel_frequencies.html?highlight=mel_frequencies#librosa.mel_frequencies>`__
|
||||
for more info.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_mels : int > 0 [scalar]
|
||||
Number of mel bins.
|
||||
|
||||
fmin : float >= 0 [scalar]
|
||||
Minimum frequency (Hz).
|
||||
|
||||
fmax : float >= 0 [scalar]
|
||||
Maximum frequency (Hz).
|
||||
|
||||
htk : bool
|
||||
If True, use HTK formula to convert Hz to mel.
|
||||
Otherwise (False), use Slaney's Auditory Toolbox.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bin_frequencies : ndarray [shape=(n_mels,)]
|
||||
Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel
|
||||
axis.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> librosa.mel_frequencies(n_mels=40)
|
||||
array([ 0. , 85.317, 170.635, 255.952,
|
||||
341.269, 426.586, 511.904, 597.221,
|
||||
682.538, 767.855, 853.173, 938.49 ,
|
||||
1024.856, 1119.114, 1222.042, 1334.436,
|
||||
1457.167, 1591.187, 1737.532, 1897.337,
|
||||
2071.84 , 2262.393, 2470.47 , 2697.686,
|
||||
2945.799, 3216.731, 3512.582, 3835.643,
|
||||
4188.417, 4573.636, 4994.285, 5453.621,
|
||||
5955.205, 6502.92 , 7101.009, 7754.107,
|
||||
8467.272, 9246.028, 10096.408, 11025. ])
|
||||
"""
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
min_mel = hz_to_mel(fmin, htk=htk)
|
||||
max_mel = hz_to_mel(fmax, htk=htk)
|
||||
|
||||
mels = np.linspace(min_mel, max_mel, n_mels)
|
||||
|
||||
return mel_to_hz(mels, htk=htk)
|
||||
|
||||
def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False,
|
||||
norm=1, dtype=np.float32):
|
||||
"""
|
||||
This function is cloned from librosa 0.7.
|
||||
Please refer to the original
|
||||
`documentation <https://librosa.org/doc/latest/generated/librosa.filters.mel.html>`__
|
||||
for more info.
|
||||
Create a Filterbank matrix to combine FFT bins into Mel-frequency bins
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sr : number > 0 [scalar]
|
||||
sampling rate of the incoming signal
|
||||
n_fft : int > 0 [scalar]
|
||||
number of FFT components
|
||||
n_mels : int > 0 [scalar]
|
||||
number of Mel bands to generate
|
||||
fmin : float >= 0 [scalar]
|
||||
lowest frequency (in Hz)
|
||||
fmax : float >= 0 [scalar]
|
||||
highest frequency (in Hz).
|
||||
If `None`, use `fmax = sr / 2.0`
|
||||
htk : bool [scalar]
|
||||
use HTK formula instead of Slaney
|
||||
norm : {None, 1, np.inf} [scalar]
|
||||
if 1, divide the triangular mel weights by the width of the mel band
|
||||
(area normalization). Otherwise, leave all the triangles aiming for
|
||||
a peak value of 1.0
|
||||
dtype : np.dtype
|
||||
The data type of the output basis.
|
||||
By default, uses 32-bit (single-precision) floating point.
|
||||
|
||||
Returns
|
||||
-------
|
||||
M : np.ndarray [shape=(n_mels, 1 + n_fft/2)]
|
||||
Mel transform matrix
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function caches at level 10.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> melfb = librosa.filters.mel(22050, 2048)
|
||||
>>> melfb
|
||||
array([[ 0. , 0.016, ..., 0. , 0. ],
|
||||
[ 0. , 0. , ..., 0. , 0. ],
|
||||
...,
|
||||
[ 0. , 0. , ..., 0. , 0. ],
|
||||
[ 0. , 0. , ..., 0. , 0. ]])
|
||||
Clip the maximum frequency to 8KHz
|
||||
>>> librosa.filters.mel(22050, 2048, fmax=8000)
|
||||
array([[ 0. , 0.02, ..., 0. , 0. ],
|
||||
[ 0. , 0. , ..., 0. , 0. ],
|
||||
...,
|
||||
[ 0. , 0. , ..., 0. , 0. ],
|
||||
[ 0. , 0. , ..., 0. , 0. ]])
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.figure()
|
||||
>>> librosa.display.specshow(melfb, x_axis='linear')
|
||||
>>> plt.ylabel('Mel filter')
|
||||
>>> plt.title('Mel filter bank')
|
||||
>>> plt.colorbar()
|
||||
>>> plt.tight_layout()
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
if fmax is None:
|
||||
fmax = float(sr) / 2
|
||||
|
||||
if norm is not None and norm != 1 and norm != np.inf:
|
||||
raise ParameterError('Unsupported norm: {}'.format(repr(norm)))
|
||||
|
||||
# Initialize the weights
|
||||
n_mels = int(n_mels)
|
||||
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||
|
||||
# Center freqs of each FFT bin
|
||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
|
||||
|
||||
fdiff = np.diff(mel_f)
|
||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
# lower and upper slopes for all bins
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i+2] / fdiff[i+1]
|
||||
|
||||
# .. then intersect them with each other and zero
|
||||
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
||||
|
||||
if norm == 1:
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, np.newaxis]
|
||||
|
||||
# Only check weights if f_mel[0] is positive
|
||||
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
|
||||
# This means we have an empty channel somewhere
|
||||
warnings.warn('Empty filters detected in mel frequency basis. '
|
||||
'Some channels will produce empty responses. '
|
||||
'Try increasing your sampling rate (and fmax) or '
|
||||
'reducing n_mels.')
|
||||
|
||||
return weights
|
||||
### ------------------End of Functions for generating kenral for Mel Spectrogram ----------------###
|
||||
|
||||
|
||||
### ------------------Functions for making STFT same as librosa ---------------------------------###
|
||||
def pad_center(data, size, axis=-1, **kwargs):
|
||||
'''Wrapper for np.pad to automatically center an array prior to padding.
|
||||
This is analogous to `str.center()`
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> # Generate a vector
|
||||
>>> data = np.ones(5)
|
||||
>>> librosa.util.pad_center(data, 10, mode='constant')
|
||||
array([ 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.])
|
||||
|
||||
>>> # Pad a matrix along its first dimension
|
||||
>>> data = np.ones((3, 5))
|
||||
>>> librosa.util.pad_center(data, 7, axis=0)
|
||||
array([[ 0., 0., 0., 0., 0.],
|
||||
[ 0., 0., 0., 0., 0.],
|
||||
[ 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1.],
|
||||
[ 1., 1., 1., 1., 1.],
|
||||
[ 0., 0., 0., 0., 0.],
|
||||
[ 0., 0., 0., 0., 0.]])
|
||||
>>> # Or its second dimension
|
||||
>>> librosa.util.pad_center(data, 7, axis=1)
|
||||
array([[ 0., 1., 1., 1., 1., 1., 0.],
|
||||
[ 0., 1., 1., 1., 1., 1., 0.],
|
||||
[ 0., 1., 1., 1., 1., 1., 0.]])
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
Vector to be padded and centered
|
||||
|
||||
size : int >= len(data) [scalar]
|
||||
Length to pad `data`
|
||||
|
||||
axis : int
|
||||
Axis along which to pad and center the data
|
||||
|
||||
kwargs : additional keyword arguments
|
||||
arguments passed to `np.pad()`
|
||||
|
||||
Returns
|
||||
-------
|
||||
data_padded : np.ndarray
|
||||
`data` centered and padded to length `size` along the
|
||||
specified axis
|
||||
|
||||
Raises
|
||||
------
|
||||
ParameterError
|
||||
If `size < data.shape[axis]`
|
||||
|
||||
See Also
|
||||
--------
|
||||
numpy.pad
|
||||
'''
|
||||
|
||||
kwargs.setdefault('mode', 'constant')
|
||||
|
||||
n = data.shape[axis]
|
||||
|
||||
lpad = int((size - n) // 2)
|
||||
|
||||
lengths = [(0, 0)] * data.ndim
|
||||
lengths[axis] = (lpad, int(size - n - lpad))
|
||||
|
||||
if lpad < 0:
|
||||
raise ParameterError(('Target size ({:d}) must be '
|
||||
'at least input size ({:d})').format(size, n))
|
||||
|
||||
return np.pad(data, lengths, **kwargs)
|
||||
|
||||
### ------------------End of functions for making STFT same as librosa ---------------------------###
|
@ -0,0 +1,535 @@
|
||||
"""
|
||||
Module containing helper functions such as overlap sum and Fourier kernels generators
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import conv1d, fold
|
||||
|
||||
import numpy as np
|
||||
from time import time
|
||||
import math
|
||||
from scipy.signal import get_window
|
||||
from scipy import signal
|
||||
from scipy import fft
|
||||
import warnings
|
||||
|
||||
from nnAudio.librosa_functions import *
|
||||
|
||||
## --------------------------- Filter Design ---------------------------##
|
||||
def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
|
||||
w_stacks = w.unsqueeze(-1).repeat((1,n_frames)).unsqueeze(0)
|
||||
# Window length + stride*(frames-1)
|
||||
output_len = w_stacks.shape[1] + stride*(w_stacks.shape[2]-1)
|
||||
return fold(w_stacks**power, (1,output_len), kernel_size=(1,n_fft), stride=stride)
|
||||
|
||||
def overlap_add(X, stride):
|
||||
n_fft = X.shape[1]
|
||||
output_len = n_fft + stride*(X.shape[2]-1)
|
||||
|
||||
return fold(X, (1,output_len), kernel_size=(1,n_fft), stride=stride).flatten(1)
|
||||
|
||||
def uniform_distribution(r1,r2, *size, device):
|
||||
return (r1 - r2) * torch.rand(*size, device=device) + r2
|
||||
|
||||
def extend_fbins(X):
|
||||
"""Extending the number of frequency bins from `n_fft//2+1` back to `n_fft` by
|
||||
reversing all bins except DC and Nyquist and append it on top of existing spectrogram"""
|
||||
X_upper = torch.flip(X[:,1:-1],(0,1))
|
||||
X_upper[:,:,:,1] = -X_upper[:,:,:,1] # For the imaganinry part, it is an odd function
|
||||
return torch.cat((X[:, :, :], X_upper), 1)
|
||||
|
||||
|
||||
def downsampling_by_n(x, filterKernel, n):
|
||||
"""A helper function that downsamples the audio by a arbitary factor n.
|
||||
It is used in CQT2010 and CQT2010v2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
|
||||
|
||||
filterKernel : str
|
||||
Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
|
||||
|
||||
n : int
|
||||
The downsampling factor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The downsampled waveform
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> x_down = downsampling_by_n(x, filterKernel)
|
||||
"""
|
||||
|
||||
x = conv1d(x,filterKernel,stride=n, padding=(filterKernel.shape[-1]-1)//2)
|
||||
return x
|
||||
|
||||
|
||||
def downsampling_by_2(x, filterKernel):
|
||||
"""A helper function that downsamples the audio by half. It is used in CQT2010 and CQT2010v2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
|
||||
|
||||
filterKernel : str
|
||||
Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The downsampled waveform
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> x_down = downsampling_by_2(x, filterKernel)
|
||||
"""
|
||||
|
||||
x = conv1d(x,filterKernel,stride=2, padding=(filterKernel.shape[-1]-1)//2)
|
||||
return x
|
||||
|
||||
|
||||
## Basic tools for computation ##
|
||||
def nextpow2(A):
|
||||
"""A helper function to calculate the next nearest number to the power of 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : float
|
||||
A float number that is going to be rounded up to the nearest power of 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The nearest power of 2 to the input number ``A``
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> nextpow2(6)
|
||||
3
|
||||
"""
|
||||
|
||||
return int(np.ceil(np.log2(A)))
|
||||
|
||||
## Basic tools for computation ##
|
||||
def prepow2(A):
|
||||
"""A helper function to calculate the next nearest number to the power of 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : float
|
||||
A float number that is going to be rounded up to the nearest power of 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The nearest power of 2 to the input number ``A``
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> nextpow2(6)
|
||||
3
|
||||
"""
|
||||
|
||||
return int(np.floor(np.log2(A)))
|
||||
|
||||
|
||||
def complex_mul(cqt_filter, stft):
|
||||
"""Since PyTorch does not support complex numbers and its operation.
|
||||
We need to write our own complex multiplication function. This one is specially
|
||||
designed for CQT usage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cqt_filter : tuple of torch.Tensor
|
||||
The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple of torch.Tensor
|
||||
The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
|
||||
"""
|
||||
|
||||
cqt_filter_real = cqt_filter[0]
|
||||
cqt_filter_imag = cqt_filter[1]
|
||||
fourier_real = stft[0]
|
||||
fourier_imag = stft[1]
|
||||
|
||||
CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
|
||||
CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
|
||||
|
||||
return CQT_real, CQT_imag
|
||||
|
||||
|
||||
def broadcast_dim(x):
|
||||
"""
|
||||
Auto broadcast input so that it can fits into a Conv1d
|
||||
"""
|
||||
|
||||
if x.dim() == 2:
|
||||
x = x[:, None, :]
|
||||
elif x.dim() == 1:
|
||||
# If nn.DataParallel is used, this broadcast doesn't work
|
||||
x = x[None, None, :]
|
||||
elif x.dim() == 3:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
|
||||
return x
|
||||
|
||||
|
||||
def broadcast_dim_conv2d(x):
|
||||
"""
|
||||
Auto broadcast input so that it can fits into a Conv2d
|
||||
"""
|
||||
|
||||
if x.dim() == 3:
|
||||
x = x[:, None, :,:]
|
||||
|
||||
else:
|
||||
raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
|
||||
return x
|
||||
|
||||
|
||||
## Kernal generation functions ##
|
||||
def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,
|
||||
freq_scale='linear', window='hann', verbose=True):
|
||||
""" This function creates the Fourier Kernel for STFT, Melspectrogram and CQT.
|
||||
Most of the parameters follow librosa conventions. Part of the code comes from
|
||||
pytorch_musicnet. https://github.com/jthickstun/pytorch_musicnet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_fft : int
|
||||
The window size
|
||||
|
||||
freq_bins : int
|
||||
Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
|
||||
|
||||
fmin : int
|
||||
The starting frequency for the lowest frequency bin.
|
||||
If freq_scale is ``no``, this argument does nothing.
|
||||
|
||||
fmax : int
|
||||
The ending frequency for the highest frequency bin.
|
||||
If freq_scale is ``no``, this argument does nothing.
|
||||
|
||||
sr : int
|
||||
The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``.
|
||||
Setting the correct sampling rate is very important for calculating the correct frequency.
|
||||
|
||||
freq_scale: 'linear', 'log', or 'no'
|
||||
Determine the spacing between each frequency bin.
|
||||
When 'linear' or 'log' is used, the bin spacing can be controlled by ``fmin`` and ``fmax``.
|
||||
If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
wsin : numpy.array
|
||||
Imaginary Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
|
||||
|
||||
wcos : numpy.array
|
||||
Real Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
|
||||
|
||||
bins2freq : list
|
||||
Mapping each frequency bin to frequency in Hz.
|
||||
|
||||
binslist : list
|
||||
The normalized frequency ``k`` in digital domain.
|
||||
This ``k`` is in the Discrete Fourier Transform equation $$
|
||||
|
||||
"""
|
||||
|
||||
if freq_bins==None: freq_bins = n_fft//2+1
|
||||
if win_length==None: win_length = n_fft
|
||||
|
||||
s = np.arange(0, n_fft, 1.)
|
||||
wsin = np.empty((freq_bins,1,n_fft))
|
||||
wcos = np.empty((freq_bins,1,n_fft))
|
||||
start_freq = fmin
|
||||
end_freq = fmax
|
||||
bins2freq = []
|
||||
binslist = []
|
||||
|
||||
# num_cycles = start_freq*d/44000.
|
||||
# scaling_ind = np.log(end_freq/start_freq)/k
|
||||
|
||||
# Choosing window shape
|
||||
|
||||
window_mask = get_window(window,int(win_length), fftbins=True)
|
||||
window_mask = pad_center(window_mask, n_fft)
|
||||
|
||||
if freq_scale == 'linear':
|
||||
if verbose==True:
|
||||
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
|
||||
f"get a valid freq range")
|
||||
start_bin = start_freq*n_fft/sr
|
||||
scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
|
||||
|
||||
for k in range(freq_bins): # Only half of the bins contain useful info
|
||||
# print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
|
||||
bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
|
||||
binslist.append((k*scaling_ind+start_bin))
|
||||
wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
|
||||
wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
|
||||
|
||||
elif freq_scale == 'log':
|
||||
if verbose==True:
|
||||
print(f"sampling rate = {sr}. Please make sure the sampling rate is correct in order to"
|
||||
f"get a valid freq range")
|
||||
start_bin = start_freq*n_fft/sr
|
||||
scaling_ind = np.log(end_freq/start_freq)/freq_bins
|
||||
|
||||
for k in range(freq_bins): # Only half of the bins contain useful info
|
||||
# print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))
|
||||
bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)
|
||||
binslist.append((np.exp(k*scaling_ind)*start_bin))
|
||||
wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
|
||||
wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
|
||||
|
||||
elif freq_scale == 'no':
|
||||
for k in range(freq_bins): # Only half of the bins contain useful info
|
||||
bins2freq.append(k*sr/n_fft)
|
||||
binslist.append(k)
|
||||
wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)
|
||||
wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)
|
||||
else:
|
||||
print("Please select the correct frequency scale, 'linear' or 'log'")
|
||||
return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)
|
||||
|
||||
|
||||
# Tools for CQT
|
||||
|
||||
def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1,
|
||||
window='hann', fmax=None, topbin_check=True):
|
||||
"""
|
||||
Automatically create CQT kernels in time domain
|
||||
"""
|
||||
|
||||
fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
|
||||
# minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
|
||||
|
||||
if (fmax != None) and (n_bins == None):
|
||||
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
|
||||
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
|
||||
|
||||
elif (fmax == None) and (n_bins != None):
|
||||
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
|
||||
|
||||
else:
|
||||
warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning)
|
||||
n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
|
||||
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
|
||||
|
||||
if np.max(freqs) > fs/2 and topbin_check==True:
|
||||
raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, \
|
||||
please reduce the n_bins'.format(np.max(freqs)))
|
||||
|
||||
tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
|
||||
specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
|
||||
|
||||
lengths = np.ceil(Q * fs / freqs)
|
||||
for k in range(0, int(n_bins)):
|
||||
freq = freqs[k]
|
||||
l = np.ceil(Q * fs / freq)
|
||||
|
||||
# Centering the kernels
|
||||
if l%2==1: # pad more zeros on RHS
|
||||
start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
|
||||
else:
|
||||
start = int(np.ceil(fftLen / 2.0 - l / 2.0))
|
||||
|
||||
sig = get_window_dispatch(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
|
||||
|
||||
if norm: # Normalizing the filter # Trying to normalize like librosa
|
||||
tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
|
||||
else:
|
||||
tempKernel[k, start:start + int(l)] = sig
|
||||
# specKernel[k, :] = fft(tempKernel[k])
|
||||
|
||||
# return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lenghts).float()
|
||||
return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
|
||||
|
||||
|
||||
def get_window_dispatch(window, N, fftbins=True):
|
||||
if isinstance(window, str):
|
||||
return get_window(window, N, fftbins=fftbins)
|
||||
elif isinstance(window, tuple):
|
||||
if window[0] == 'gaussian':
|
||||
assert window[1] >= 0
|
||||
sigma = np.floor(- N / 2 / np.sqrt(- 2 * np.log(10**(- window[1] / 20))))
|
||||
return get_window(('gaussian', sigma), N, fftbins=fftbins)
|
||||
else:
|
||||
Warning("Tuple windows may have undesired behaviour regarding Q factor")
|
||||
elif isinstance(window, float):
|
||||
Warning("You are using Kaiser window with beta factor " + str(window) + ". Correct behaviour not checked.")
|
||||
else:
|
||||
raise Exception("The function get_window from scipy only supports strings, tuples and floats.")
|
||||
|
||||
|
||||
|
||||
def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
|
||||
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
|
||||
for how to multiple the STFT result with the CQT kernel
|
||||
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
|
||||
a constant Q transform.” (1992)."""
|
||||
|
||||
# STFT, converting the audio input from time domain to frequency domain
|
||||
try:
|
||||
x = padding(x) # When center == True, we need padding at the beginning and ending
|
||||
except:
|
||||
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
|
||||
"padding with reflection mode might not be the best choice, try using constant padding",
|
||||
UserWarning)
|
||||
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
|
||||
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
|
||||
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
|
||||
|
||||
return torch.stack((CQT_real, CQT_imag),-1)
|
||||
|
||||
def get_cqt_complex2(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None):
|
||||
"""Multiplying the STFT result with the cqt_kernel, check out the 1992 CQT paper [1]
|
||||
for how to multiple the STFT result with the CQT kernel
|
||||
[2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of
|
||||
a constant Q transform.” (1992)."""
|
||||
|
||||
# STFT, converting the audio input from time domain to frequency domain
|
||||
try:
|
||||
x = padding(x) # When center == True, we need padding at the beginning and ending
|
||||
except:
|
||||
warnings.warn(f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
|
||||
"padding with reflection mode might not be the best choice, try using constant padding",
|
||||
UserWarning)
|
||||
x = torch.nn.functional.pad(x, (cqt_kernels_real.shape[-1]//2, cqt_kernels_real.shape[-1]//2))
|
||||
|
||||
|
||||
|
||||
if wcos==None or wsin==None:
|
||||
CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
|
||||
CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
|
||||
|
||||
else:
|
||||
fourier_real = conv1d(x, wcos, stride=hop_length)
|
||||
fourier_imag = conv1d(x, wsin, stride=hop_length)
|
||||
# Multiplying input with the CQT kernel in freq domain
|
||||
CQT_real, CQT_imag = complex_mul((cqt_kernels_real, cqt_kernels_imag),
|
||||
(fourier_real, fourier_imag))
|
||||
|
||||
return torch.stack((CQT_real, CQT_imag),-1)
|
||||
|
||||
|
||||
|
||||
|
||||
def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
|
||||
"""
|
||||
Calculate the highest frequency we need to preserve and the lowest frequency we allow
|
||||
to pass through.
|
||||
Note that frequency is on a scale from 0 to 1 where 0 is 0 and 1 is Nyquist frequency of
|
||||
the signal BEFORE downsampling.
|
||||
"""
|
||||
|
||||
# transitionBandwidth = 0.03
|
||||
passbandMax = band_center / (1 + transitionBandwidth)
|
||||
stopbandMin = band_center * (1 + transitionBandwidth)
|
||||
|
||||
# Unlike the filter tool we used online yesterday, this tool does
|
||||
# not allow us to specify how closely the filter matches our
|
||||
# specifications. Instead, we specify the length of the kernel.
|
||||
# The longer the kernel is, the more precisely it will match.
|
||||
# kernelLength = 256
|
||||
|
||||
# We specify a list of key frequencies for which we will require
|
||||
# that the filter match a specific output gain.
|
||||
# From [0.0 to passbandMax] is the frequency range we want to keep
|
||||
# untouched and [stopbandMin, 1.0] is the range we want to remove
|
||||
keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0]
|
||||
|
||||
# We specify a list of output gains to correspond to the key
|
||||
# frequencies listed above.
|
||||
# The first two gains are 1.0 because they correspond to the first
|
||||
# two key frequencies. the second two are 0.0 because they
|
||||
# correspond to the stopband frequencies
|
||||
gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
|
||||
|
||||
# This command produces the filter kernel coefficients
|
||||
filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
|
||||
|
||||
return filterKernel.astype(np.float32)
|
||||
|
||||
def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
|
||||
"""Used in CQT2010 and CQT2010v2"""
|
||||
|
||||
window_bandwidth = 1.5 # for hann window
|
||||
filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
|
||||
sr, hop_length, downsample_factor = early_downsample(sr,
|
||||
hop_length,
|
||||
n_octaves,
|
||||
sr//2,
|
||||
filter_cutoff)
|
||||
if downsample_factor != 1:
|
||||
if verbose==True:
|
||||
print("Can do early downsample, factor = ", downsample_factor)
|
||||
earlydownsample=True
|
||||
# print("new sr = ", sr)
|
||||
# print("new hop_length = ", hop_length)
|
||||
early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor,
|
||||
kernelLength=256,
|
||||
transitionBandwidth=0.03)
|
||||
early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
|
||||
|
||||
else:
|
||||
if verbose==True:
|
||||
print("No early downsampling is required, downsample_factor = ", downsample_factor)
|
||||
early_downsample_filter = None
|
||||
earlydownsample=False
|
||||
|
||||
return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
|
||||
|
||||
def early_downsample(sr, hop_length, n_octaves,
|
||||
nyquist, filter_cutoff):
|
||||
'''Return new sampling rate and hop length after early dowansampling'''
|
||||
downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
|
||||
# print("downsample_count = ", downsample_count)
|
||||
downsample_factor = 2**(downsample_count)
|
||||
|
||||
hop_length //= downsample_factor # Getting new hop_length
|
||||
new_sr = sr / float(downsample_factor) # Getting new sampling rate
|
||||
sr = new_sr
|
||||
|
||||
return sr, hop_length, downsample_factor
|
||||
|
||||
|
||||
# The following two downsampling count functions are obtained from librosa CQT
|
||||
# They are used to determine the number of pre resamplings if the starting and ending frequency
|
||||
# are both in low frequency regions.
|
||||
def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
|
||||
'''Compute the number of early downsampling operations'''
|
||||
|
||||
downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
|
||||
filter_cutoff)) - 1) - 1)
|
||||
# print("downsample_count1 = ", downsample_count1)
|
||||
num_twos = nextpow2(hop_length)
|
||||
downsample_count2 = max(0, num_twos - n_octaves + 1)
|
||||
# print("downsample_count2 = ",downsample_count2)
|
||||
|
||||
return min(downsample_count1, downsample_count2)
|
||||
|
||||
def early_downsample(sr, hop_length, n_octaves,
|
||||
nyquist, filter_cutoff):
|
||||
'''Return new sampling rate and hop length after early dowansampling'''
|
||||
downsample_count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
|
||||
# print("downsample_count = ", downsample_count)
|
||||
downsample_factor = 2**(downsample_count)
|
||||
|
||||
hop_length //= downsample_factor # Getting new hop_length
|
||||
new_sr = sr / float(downsample_factor) # Getting new sampling rate
|
||||
|
||||
sr = new_sr
|
||||
|
||||
return sr, hop_length, downsample_factor
|
@ -0,0 +1,37 @@
|
||||
import setuptools
|
||||
import codecs
|
||||
import os.path
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
def read(rel_path):
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
with codecs.open(os.path.join(here, rel_path), 'r') as fp:
|
||||
return fp.read()
|
||||
|
||||
def get_version(rel_path):
|
||||
for line in read(rel_path).splitlines():
|
||||
if line.startswith('__version__'):
|
||||
delim = '"' if '"' in line else "'"
|
||||
return line.split(delim)[1]
|
||||
else:
|
||||
raise RuntimeError("Unable to find version string.")
|
||||
|
||||
setuptools.setup(
|
||||
name="nnAudio", # Replace with your own username
|
||||
version=get_version("nnAudio/__init__.py"),
|
||||
author="KinWaiCheuk",
|
||||
author_email="u3500684@connect.hku.hk",
|
||||
description="A fast GPU audio processing toolbox with 1D convolutional neural network",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/KinWaiCheuk/nnAudio",
|
||||
packages=setuptools.find_packages(),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires='>=3.6',
|
||||
)
|
@ -0,0 +1,38 @@
|
||||
# Creating parameters for STFT test
|
||||
"""
|
||||
It is equivalent to
|
||||
[(1024, 128, 'ones'),
|
||||
(1024, 128, 'hann'),
|
||||
(1024, 128, 'hamming'),
|
||||
(2048, 128, 'ones'),
|
||||
(2048, 512, 'ones'),
|
||||
(2048, 128, 'hann'),
|
||||
(2048, 512, 'hann'),
|
||||
(2048, 128, 'hamming'),
|
||||
(2048, 512, 'hamming'),
|
||||
(None, None, None)]
|
||||
"""
|
||||
|
||||
stft_parameters = []
|
||||
n_fft = [1024,2048]
|
||||
hop_length = {128,512,1024}
|
||||
window = ['ones', 'hann', 'hamming']
|
||||
for i in n_fft:
|
||||
for k in window:
|
||||
for j in hop_length:
|
||||
if j < (i/2):
|
||||
stft_parameters.append((i,j,k))
|
||||
stft_parameters.append((256, None, 'hann'))
|
||||
|
||||
stft_with_win_parameters = []
|
||||
n_fft = [512,1024]
|
||||
win_length = [400, 900]
|
||||
hop_length = {128,256}
|
||||
for i in n_fft:
|
||||
for j in win_length:
|
||||
if j < i:
|
||||
for k in hop_length:
|
||||
if k < (i/2):
|
||||
stft_with_win_parameters.append((i,j,k))
|
||||
|
||||
mel_win_parameters = [(512,400), (1024, 1000)]
|
@ -0,0 +1,373 @@
|
||||
import pytest
|
||||
import librosa
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.signal import chirp, sweep_poly
|
||||
from nnAudio.Spectrogram import *
|
||||
from parameters import *
|
||||
|
||||
gpu_idx=0
|
||||
|
||||
# librosa example audio for testing
|
||||
example_y, example_sr = librosa.load(librosa.util.example_audio_file())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_inverse2(n_fft, hop_length, window, device):
|
||||
x = torch.tensor(example_y,device=device)
|
||||
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
|
||||
istft = iSTFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
|
||||
X = stft(x.unsqueeze(0), output_format="Complex")
|
||||
x_recon = istft(X, length=x.shape[0], onesided=True).squeeze()
|
||||
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-5, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_inverse(n_fft, hop_length, window, device):
|
||||
x = torch.tensor(example_y, device=device)
|
||||
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, iSTFT=True).to(device)
|
||||
X = stft(x.unsqueeze(0), output_format="Complex")
|
||||
x_recon = stft.inverse(X, length=x.shape[0]).squeeze()
|
||||
assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
|
||||
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
|
||||
|
||||
# def test_inverse_GPU(n_fft, hop_length, window):
|
||||
# x = torch.tensor(example_y,device=f'cuda:{gpu_idx}')
|
||||
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
|
||||
# X = stft(x.unsqueeze(0), output_format="Complex")
|
||||
# x_recon = stft.inverse(X, num_samples=x.shape[0]).squeeze()
|
||||
# assert np.allclose(x.cpu(), x_recon.cpu(), rtol=1e-3, atol=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_stft_complex(n_fft, hop_length, window, device):
|
||||
x = example_y
|
||||
stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
|
||||
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
|
||||
X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
|
||||
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
|
||||
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
|
||||
|
||||
assert real_diff and imag_diff
|
||||
|
||||
# @pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
|
||||
# def test_stft_complex_GPU(n_fft, hop_length, window):
|
||||
# x = example_y
|
||||
# stft = STFT(n_fft=n_fft, hop_length=hop_length, window=window, device=f'cuda:{gpu_idx}')
|
||||
# X = stft(torch.tensor(x,device=f'cuda:{gpu_idx}').unsqueeze(0), output_format="Complex")
|
||||
# X_real, X_imag = X[:, :, :, 0].squeeze().detach().cpu(), X[:, :, :, 1].squeeze().detach().cpu()
|
||||
# X_librosa = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, window=window)
|
||||
# real_diff, imag_diff = np.allclose(X_real, X_librosa.real, rtol=1e-3, atol=1e-3), \
|
||||
# np.allclose(X_imag, X_librosa.imag, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# assert real_diff and imag_diff
|
||||
|
||||
@pytest.mark.parametrize("n_fft, win_length, hop_length", stft_with_win_parameters)
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_stft_complex_winlength(n_fft, win_length, hop_length, device):
|
||||
x = example_y
|
||||
stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Complex")
|
||||
X_real, X_imag = X[:, :, :, 0].squeeze(), X[:, :, :, 1].squeeze()
|
||||
X_librosa = librosa.stft(x, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
|
||||
real_diff, imag_diff = np.allclose(X_real.cpu(), X_librosa.real, rtol=1e-3, atol=1e-3), \
|
||||
np.allclose(X_imag.cpu(), X_librosa.imag, rtol=1e-3, atol=1e-3)
|
||||
assert real_diff and imag_diff
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_stft_magnitude(device):
|
||||
x = example_y
|
||||
stft = STFT(n_fft=2048, hop_length=512).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Magnitude").squeeze()
|
||||
X_librosa, _ = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
|
||||
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_stft_phase(device):
|
||||
x = example_y
|
||||
stft = STFT(n_fft=2048, hop_length=512).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0), output_format="Phase")
|
||||
X_real, X_imag = torch.cos(X).squeeze(), torch.sin(X).squeeze()
|
||||
_, X_librosa = librosa.core.magphase(librosa.stft(x, n_fft=2048, hop_length=512))
|
||||
|
||||
real_diff, imag_diff = np.mean(np.abs(X_real.cpu().numpy() - X_librosa.real)), \
|
||||
np.mean(np.abs(X_imag.cpu().numpy() - X_librosa.imag))
|
||||
|
||||
# I find that np.allclose is too strict for allowing phase to be similar to librosa.
|
||||
# Hence for phase we use average element-wise distance as the test metric.
|
||||
assert real_diff < 2e-4 and imag_diff < 2e-4
|
||||
|
||||
@pytest.mark.parametrize("n_fft, win_length", mel_win_parameters)
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_mel_spectrogram(n_fft, win_length, device):
|
||||
x = example_y
|
||||
melspec = MelSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=512).to(device)
|
||||
X = melspec(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
|
||||
X_librosa = librosa.feature.melspectrogram(x, n_fft=n_fft, win_length=win_length, hop_length=512)
|
||||
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_1992(device):
|
||||
# Log sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='logarithmic')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT1992(sr=fs, fmin=220, output_format="Magnitude",
|
||||
n_bins=80, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
|
||||
|
||||
# Complex
|
||||
stft = CQT1992(sr=fs, fmin=220, output_format="Complex",
|
||||
n_bins=80, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
|
||||
# Phase
|
||||
stft = CQT1992(sr=fs, fmin=220, output_format="Phase",
|
||||
n_bins=160, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
|
||||
assert True
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_2010(device):
|
||||
# Log sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='logarithmic')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT2010(sr=fs, fmin=110, output_format="Magnitude",
|
||||
n_bins=160, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
|
||||
# Complex
|
||||
stft = CQT2010(sr=fs, fmin=110, output_format="Complex",
|
||||
n_bins=160, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
|
||||
# Phase
|
||||
stft = CQT2010(sr=fs, fmin=110, output_format="Phase",
|
||||
n_bins=160, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
assert True
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_1992_v2_log(device):
|
||||
# Log sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='logarithmic')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-mag-ground-truth.npy")
|
||||
X = torch.log(X + 1e-5)
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Complex
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-complex-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Phase
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-1992-phase-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_1992_v2_linear(device):
|
||||
# Linear sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='linear')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Magnitude",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-mag-ground-truth.npy")
|
||||
X = torch.log(X + 1e-5)
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Complex
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Complex",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-complex-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Phase
|
||||
stft = CQT1992v2(sr=fs, fmin=55, output_format="Phase",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-1992-phase-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_2010_v2_log(device):
|
||||
# Log sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='logarithmic')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
X = torch.log(X + 1e-2)
|
||||
# np.save("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth", X.cpu())
|
||||
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-mag-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Complex
|
||||
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
# np.save("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth", X.cpu())
|
||||
ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-complex-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# # Phase
|
||||
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
|
||||
# n_bins=207, bins_per_octave=24)
|
||||
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
# # np.save("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth", X.cpu())
|
||||
# ground_truth = np.load("tests/ground-truths/log-sweep-cqt-2010-phase-ground-truth.npy")
|
||||
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_cqt_2010_v2_linear(device):
|
||||
# Linear sweep case
|
||||
fs = 44100
|
||||
t = 1
|
||||
f0 = 55
|
||||
f1 = 22050
|
||||
s = np.linspace(0, t, fs*t)
|
||||
x = chirp(s, f0, 1, f1, method='linear')
|
||||
x = x.astype(dtype=np.float32)
|
||||
|
||||
# Magnitude
|
||||
stft = CQT2010v2(sr=fs, fmin=55, output_format="Magnitude",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
X = torch.log(X + 1e-2)
|
||||
# np.save("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth", X.cpu())
|
||||
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-mag-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Complex
|
||||
stft = CQT2010v2(sr=fs, fmin=55, output_format="Complex",
|
||||
n_bins=207, bins_per_octave=24).to(device)
|
||||
X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
# np.save("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth", X.cpu())
|
||||
ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-complex-ground-truth.npy")
|
||||
assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Phase
|
||||
# stft = CQT2010v2(sr=fs, fmin=55, device=device, output_format="Phase",
|
||||
# n_bins=207, bins_per_octave=24)
|
||||
# X = stft(torch.tensor(x, device=device).unsqueeze(0))
|
||||
# # np.save("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth", X.cpu())
|
||||
# ground_truth = np.load("tests/ground-truths/linear-sweep-cqt-2010-phase-ground-truth.npy")
|
||||
# assert np.allclose(X.cpu(), ground_truth, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', f'cuda:{gpu_idx}'])
|
||||
def test_mfcc(device):
|
||||
x = example_y
|
||||
mfcc = MFCC(sr=example_sr).to(device)
|
||||
X = mfcc(torch.tensor(x, device=device).unsqueeze(0)).squeeze()
|
||||
X_librosa = librosa.feature.mfcc(x, sr=example_sr)
|
||||
assert np.allclose(X.cpu(), X_librosa, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
x = torch.randn((4,44100)) # Create a batch of input for the following Data.Parallel test
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_STFT_Parallel(device):
|
||||
spec_layer = STFT(hop_length=512, n_fft=2048, window='hann',
|
||||
freq_scale='no',
|
||||
output_format='Complex').to(device)
|
||||
inverse_spec_layer = iSTFT(hop_length=512, n_fft=2048, window='hann',
|
||||
freq_scale='no').to(device)
|
||||
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
inverse_spec_layer_parallel = torch.nn.DataParallel(inverse_spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
x_recon = inverse_spec_layer_parallel(spec, onesided=True, length=x.shape[-1])
|
||||
|
||||
assert np.allclose(x_recon.detach().cpu(), x.detach().cpu(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_MelSpectrogram_Parallel(device):
|
||||
spec_layer = MelSpectrogram(sr=22050, n_fft=2048, n_mels=128, hop_length=512,
|
||||
window='hann', center=True, pad_mode='reflect',
|
||||
power=2.0, htk=False, fmin=0.0, fmax=None, norm=1,
|
||||
verbose=True).to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_MFCC_Parallel(device):
|
||||
spec_layer = MFCC().to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_CQT1992_Parallel(device):
|
||||
spec_layer = CQT1992(fmin=110, n_bins=60, bins_per_octave=12).to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_CQT1992v2_Parallel(device):
|
||||
spec_layer = CQT1992v2().to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_CQT2010_Parallel(device):
|
||||
spec_layer = CQT2010().to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
||||
|
||||
@pytest.mark.parametrize("device", [f'cuda:{gpu_idx}'])
|
||||
def test_CQT2010v2_Parallel(device):
|
||||
spec_layer = CQT2010v2().to(device)
|
||||
spec_layer_parallel = torch.nn.DataParallel(spec_layer)
|
||||
spec = spec_layer_parallel(x)
|
Loading…
Reference in new issue