You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.3 KiB
111 lines
3.3 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import librosa.display
|
|
import matplotlib.pylab as plt
|
|
|
|
# Public plotting helpers exported by ``from <module> import *``.
__all__ = [
    'plot_alignment',
    'plot_spectrogram',
    'plot_waveform',
    'plot_multihead_alignments',
    'plot_multilayer_multihead_alignments',
]
|
|
|
|
|
|
def plot_alignment(alignment, title=None):
    """Draw a single attention alignment as a heat map with a colorbar.

    Args:
        alignment: 2-D array-like handed to ``imshow``. The original comment
            gives its layout as ``[encoder_steps, decoder_steps]``; rows are
            labelled "Encoder timestep" and columns "Decoder timestep".
        title: Optional text; when not ``None`` it is appended below the
            x-axis label, separated by a blank line.

    Returns:
        The newly created matplotlib figure.
    """
    figure, axis = plt.subplots(figsize=(6, 4))
    heatmap = axis.imshow(
        alignment,
        interpolation='none',
        origin='lower',
        aspect='auto',
    )
    figure.colorbar(heatmap, ax=axis)

    # The title is rendered as extra lines under the axis label rather than
    # as a figure title.
    caption = 'Decoder timestep'
    if title is not None:
        caption = caption + '\n\n' + title

    plt.xlabel(caption)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()
    return figure
|
|
|
|
|
|
def plot_multihead_alignments(alignments, title=None):
    """Plot one attention heat map per head, side by side in a single row.

    Args:
        alignments: Array exposing ``.shape`` and integer indexing, laid out
            as ``[N, encoder_steps, decoder_steps]``; ``alignments[i]`` is
            drawn in the i-th column.
        title: Optional text; when not ``None`` it is appended below the
            x-axis label of every subplot, separated by a blank line.

    Returns:
        The newly created matplotlib figure containing ``N`` subplots that
        share their y axis.
    """
    num_subplots = alignments.shape[0]

    # squeeze=False keeps `axes` a 2-D array of shape (1, N) even when N == 1.
    # With squeeze=True a single subplot is returned as a bare Axes object,
    # which cannot be iterated and made the loop below raise TypeError.
    fig, axes = plt.subplots(
        figsize=(6 * num_subplots, 4),
        ncols=num_subplots,
        sharey=True,
        squeeze=False)

    # The label is identical for every subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, ax in enumerate(axes[0]):
        im = ax.imshow(
            alignments[i], aspect='auto', origin='lower', interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel(xlabel)
        # The y axis is shared, so only the left-most subplot is labelled.
        if i == 0:
            ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    return fig
|
|
|
|
|
|
def plot_multilayer_multihead_alignments(alignments, title=None):
    """Plot a grid of attention heat maps: one row per layer, one column per head.

    Args:
        alignments: Array exposing ``.shape`` and tuple indexing, laid out as
            ``[num_layers, num_heads, encoder_steps, decoder_steps]``;
            ``alignments[i, j]`` is drawn at grid position (row i, column j).
        title: Optional text; when not ``None`` it is appended below the
            x-axis label of the bottom row, separated by a blank line.

    Returns:
        The newly created matplotlib figure whose subplots share both axes.
    """
    num_layers, num_heads, *_ = alignments.shape

    # squeeze=False guarantees `axes` is always a 2-D array of shape
    # (num_layers, num_heads). With squeeze=True a single layer or a single
    # head collapses it to 1-D (or to a bare Axes), and the nested iteration
    # below raised TypeError.
    fig, axes = plt.subplots(
        figsize=(6 * num_heads, 4 * num_layers),
        nrows=num_layers,
        ncols=num_heads,
        sharex=True,
        sharey=True,
        squeeze=False)

    # The label does not depend on the subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            im = ax.imshow(
                alignments[i, j],
                aspect='auto',
                origin='lower',
                interpolation='none')
            fig.colorbar(im, ax=ax)
            # Axes are shared, so only the bottom row gets an x label and
            # only the left-most column gets a y label.
            if i == num_layers - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    return fig
|
|
|
|
|
|
def plot_spectrogram(spec):
    """Render a spectrogram as a heat map with a colorbar.

    Args:
        spec: 2-D array-like handed to ``imshow``. The original comment
            describes it as ``[C, T]`` in the librosa convention; rows are
            labelled "Channels" and columns "Frames".

    Returns:
        The newly created matplotlib figure.
    """
    figure, axis = plt.subplots(figsize=(12, 3))
    image = axis.imshow(
        spec,
        interpolation='none',
        origin="lower",
        aspect="auto",
    )
    plt.colorbar(image, ax=axis)
    plt.ylabel("Channels")
    plt.xlabel("Frames")
    plt.tight_layout()
    return figure
|
|
|
|
|
|
def plot_waveform(wav, sr=22050):
    """Plot an audio waveform against time.

    Args:
        wav: Audio samples, passed straight to librosa's waveform plotter.
        sr: Sampling rate used to scale the time axis. Defaults to 22050.

    Returns:
        The newly created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))

    # Prefer `waveplot` so existing installations render exactly as before;
    # fall back to `waveshow` on librosa releases where `waveplot` no longer
    # exists.
    draw_wave = getattr(librosa.display, "waveplot", None)
    if draw_wave is None:
        draw_wave = librosa.display.waveshow

    # Honour the caller's sampling rate (it was previously hard-coded to
    # 22050, silently ignoring `sr`) and draw on this figure's own axes.
    draw_wave(wav, sr=sr, ax=ax)

    # No colorbar is drawn: a waveform has no color-mapped data, so the
    # artist returned by librosa is not a meaningful colorbar source.
    plt.tight_layout()
    return fig
|