Merge pull request #1493 from zh794390558/dtw

[audio] dtw metric
pull/1518/head
Hui Zhang 3 years ago committed by GitHub
commit 39e5a81a2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,3 +2,4 @@
Date: 2022-2-25, Author: Hui Zhang. Date: 2022-2-25, Author: Hui Zhang.
- Refactor architecture. - Refactor architecture.
- dtw distance and mcd style dtw

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .dtw import dtw_distance
from .mcd import mcd_distance

@ -0,0 +1,42 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from dtaidistance import dtw_ndim
__all__ = [
'dtw_distance',
]
def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float:
"""dtw distance
Dynamic Time Warping.
This function keeps a compact matrix, not the full warping paths matrix.
Uses dynamic programming to compute:
wps[i, j] = (s1[i]-s2[j])**2 + min(
wps[i-1, j ] + penalty, // vertical / insertion / expansion
wps[i , j-1] + penalty, // horizontal / deletion / compression
wps[i-1, j-1]) // diagonal / match
dtw = sqrt(wps[-1, -1])
Args:
xs (np.ndarray): ref sequence, [T,D]
ys (np.ndarray): hyp sequence, [T,D]
Returns:
float: dtw distance
"""
return dtw_ndim.distance(xs, ys)

@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mcd.metrics_fast as mt
from mcd import dtw
__all__ = [
'mcd_distance',
]
def mcd_distance(xs: np.ndarray, ys: np.ndarray, cost_fn=mt.logSpecDbDist):
"""Mel cepstral distortion (MCD), dtw distance.
Dynamic Time Warping.
Uses dynamic programming to compute:
wps[i, j] = cost_fn(xs[i], ys[j]) + min(
wps[i-1, j ], // vertical / insertion / expansion
wps[i , j-1], // horizontal / deletion / compression
wps[i-1, j-1]) // diagonal / match
dtw = sqrt(wps[-1, -1])
Cost Function:
logSpecDbConst = 10.0 / math.log(10.0) * math.sqrt(2.0)
def logSpecDbDist(x, y):
diff = x - y
return logSpecDbConst * math.sqrt(np.inner(diff, diff))
Args:
xs (np.ndarray): ref sequence, [T,D]
ys (np.ndarray): hyp sequence, [T,D]
Returns:
float: dtw distance
"""
min_cost, path = dtw.dtw(xs, ys, cost_fn)
return min_cost

@ -59,6 +59,8 @@ setuptools.setup(
'resampy >= 0.2.2', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'soundfile >= 0.9.0',
'colorlog', 'colorlog',
'dtaidistance >= 2.3.6',
'mcd >= 0.4',
], ) ], )
remove_version_py() remove_version_py()

Loading…
Cancel
Save