# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa.display
import matplotlib.pylab as plt

# Public API of this module: the names exported by a star-import.
__all__ = [
    "plot_alignment",
    "plot_spectrogram",
    "plot_waveform",
    "plot_multihead_alignments",
    "plot_multilayer_multihead_alignments",
]


def plot_alignment(alignment, title=None):
    """Render one attention alignment matrix as a heat map.

    Args:
        alignment: 2-D array laid out as [encoder_steps, decoder_steps].
        title: Optional caption. When given, it is appended to the x-axis
            label after a blank line instead of being set as an axes title.

    Returns:
        The matplotlib figure holding the heat map and its colorbar.
    """
    figure, axis = plt.subplots(figsize=(6, 4))
    heatmap = axis.imshow(
        alignment, aspect='auto', origin='lower', interpolation='none')
    figure.colorbar(heatmap, ax=axis)

    # Build the x label from its parts; the caption (if any) goes last.
    label_parts = ['Decoder timestep']
    if title is not None:
        label_parts.append(title)
    axis.set_xlabel('\n\n'.join(label_parts))
    axis.set_ylabel('Encoder timestep')

    figure.tight_layout()
    return figure


def plot_multihead_alignments(alignments, title=None):
    """Plot several alignment matrices side by side in a single row.

    Args:
        alignments: Array laid out as [N, encoder_steps, decoder_steps]; one
            heat map is drawn per entry along the first axis. It must expose
            ``.shape`` and support integer indexing.
        title: Optional caption appended (after a blank line) to the x-axis
            label of every subplot.

    Returns:
        The matplotlib figure containing the row of N heat maps.
    """
    num_subplots = alignments.shape[0]

    # squeeze=False keeps `axes` a 2-D array of shape (1, N) even when
    # N == 1. With squeeze=True a single subplot is returned as a bare Axes
    # object, which cannot be iterated and made the loop below fail.
    fig, axes = plt.subplots(
        figsize=(6 * num_subplots, 4),
        ncols=num_subplots,
        sharey=True,
        squeeze=False)

    # The label is identical for every subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, ax in enumerate(axes[0]):
        im = ax.imshow(
            alignments[i], aspect='auto', origin='lower', interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel(xlabel)
        # The y axis is shared, so only the left-most subplot is labelled.
        if i == 0:
            ax.set_ylabel('Encoder timestep')
    fig.tight_layout()
    return fig


def plot_multilayer_multihead_alignments(alignments, title=None):
    """Plot a grid of alignment matrices: one row per layer, one column per head.

    Args:
        alignments: Array laid out as
            [num_layers, num_heads, encoder_steps, decoder_steps]. It must
            expose ``.shape`` and support ``alignments[i, j]`` indexing.
        title: Optional caption appended (after a blank line) to the x-axis
            label of the bottom row.

    Returns:
        The matplotlib figure containing the num_layers x num_heads grid.
    """
    num_layers, num_heads, *_ = alignments.shape

    # squeeze=False guarantees a 2-D (num_layers, num_heads) array of Axes.
    # With squeeze=True a single layer or a single head collapses the array
    # to 1-D (or to a bare Axes), which broke the nested iteration below.
    fig, axes = plt.subplots(
        figsize=(6 * num_heads, 4 * num_layers),
        nrows=num_layers,
        ncols=num_heads,
        sharex=True,
        sharey=True,
        squeeze=False)

    # The label does not depend on the subplot, so build it once.
    xlabel = 'Decoder timestep'
    if title is not None:
        xlabel += '\n\n' + title

    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            im = ax.imshow(
                alignments[i, j],
                aspect='auto',
                origin='lower',
                interpolation='none')
            fig.colorbar(im, ax=ax)
            # Axes are shared, so only the outer edge of the grid is
            # labelled: the bottom row gets x labels, the first column y.
            if i == num_layers - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel('Encoder timestep')
    fig.tight_layout()
    return fig


def plot_spectrogram(spec):
    """Draw a spectrogram as a heat map with a colorbar.

    Args:
        spec: 2-D array laid out as [C, T] (librosa convention): channels
            along the vertical axis, frames along the horizontal axis.

    Returns:
        The matplotlib figure holding the spectrogram image.
    """
    figure, axis = plt.subplots(figsize=(12, 3))
    image = axis.imshow(
        spec, aspect="auto", origin="lower", interpolation='none')
    figure.colorbar(image, ax=axis)
    axis.set_xlabel("Frames")
    axis.set_ylabel("Channels")
    figure.tight_layout()
    return figure


def plot_waveform(wav, sr=22050):
    """Plot an audio waveform using librosa's display helpers.

    Args:
        wav: Audio samples, forwarded unchanged to librosa's waveform
            plotting function.
        sr: Sampling rate of ``wav``, forwarded to librosa. Defaults to 22050.

    Returns:
        The matplotlib figure holding the waveform plot.
    """
    fig, ax = plt.subplots(figsize=(12, 3))

    # Newer librosa releases provide `waveshow` and no longer ship
    # `waveplot`; keep using `waveplot` where it exists and fall back
    # otherwise so the helper works on both.
    draw_waveform = getattr(librosa.display, "waveplot", None)
    if draw_waveform is None:
        draw_waveform = librosa.display.waveshow

    # Forward the caller's sampling rate (it was previously hard-coded to
    # 22050, silently ignoring `sr`). No colorbar is attached: a waveform
    # carries no colour-mapped data.
    draw_waveform(wav, sr=sr, ax=ax)
    fig.tight_layout()
    return fig