# PaddleSpeech/examples/commonvoice/whisper/visualize.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
from matplotlib.ticker import MaxNLocator

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()
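
# Expected layout of the results JSON, inferred from the fields read below
# (the exact schema is produced by the accompanying evaluation script; the
# numbers here are placeholders):
#   {
#       "metrics": {"wer": 0.12, "cer": 0.05},
#       "details": [
#           {"audio": "...", "reference": "...", "hypothesis": "...",
#            "duration": 3.2},   # "duration" is optional
#           ...
#       ]
#   }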


def plot_waveform(audio_path, output_dir=None, show=True):
    """Plot waveform of audio file."""
    try:
        audio, sr = sf.read(audio_path)
        if audio.ndim > 1:
            audio = audio[:, 0]  # Take first channel if stereo
        duration = len(audio) / sr
        time_axis = np.linspace(0, duration, len(audio))
        plt.figure(figsize=(12, 4))
        plt.plot(time_axis, audio, color='#1f77b4')
        plt.title(f"Waveform: {os.path.basename(audio_path)}")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Amplitude")
        plt.grid(alpha=0.3)
        if output_dir:
            output_path = os.path.join(
                output_dir,
                f"{os.path.splitext(os.path.basename(audio_path))[0]}_waveform.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Waveform plot saved to {output_path}")
        if show:
            plt.show()
        else:
            plt.close()
        return True
    except Exception as e:
        logger.error(f"Error plotting waveform: {e}")
        return False


def plot_transcription_comparison(reference, hypothesis, audio_path=None,
                                  output_dir=None, show=True):
    """Plot comparison between reference and hypothesis transcriptions."""
    try:
        fig, ax = plt.subplots(figsize=(12, 6))
        # Title with audio file name if provided
        title = "Reference vs. Hypothesis"
        if audio_path:
            title += f": {os.path.basename(audio_path)}"
        plt.title(title)
        # Create comparison table
        rows = []
        rows.append(["Reference", reference])
        rows.append(["Hypothesis", hypothesis])
        # Calculate word-level differences
        ref_words = reference.strip().split()
        hyp_words = hypothesis.strip().split()
        from difflib import SequenceMatcher
        matcher = SequenceMatcher(None, ref_words, hyp_words)
        # Mark differing words with **...** so they stand out in the table
        ref_colored = []
        hyp_colored = []
        for op, i1, i2, j1, j2 in matcher.get_opcodes():
            if op == 'equal':
                ref_colored.extend(ref_words[i1:i2])
                hyp_colored.extend(hyp_words[j1:j2])
            elif op == 'replace':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
            elif op == 'delete':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
            elif op == 'insert':
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
        rows.append(["Ref (Diff)", " ".join(ref_colored)])
        rows.append(["Hyp (Diff)", " ".join(hyp_colored)])
        # Count word errors from the alignment (substitutions, deletions and
        # insertions); a plain zip() comparison would miss words beyond the
        # end of the shorter sequence.
        word_error = sum(
            max(i2 - i1, j2 - j1)
            for op, i1, i2, j1, j2 in matcher.get_opcodes()
            if op != 'equal')
        if len(ref_words) > 0:
            word_error_rate = word_error / len(ref_words)
        else:
            word_error_rate = 0
        rows.append([
            "Word Error",
            f"{word_error}/{len(ref_words)} ({word_error_rate:.2%})"
        ])
        # Create table
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(
            cellText=rows, colWidths=[0.15, 0.85], loc='center',
            cellLoc='left')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)
        # Highlight the diff rows (rows 2 and 3) with a light blue background
        for i in [2, 3]:
            for j in range(2):
                table[i, j].set_facecolor('#E6F3FF')
        if output_dir:
            base_name = os.path.splitext(
                os.path.basename(audio_path))[0] if audio_path else "comparison"
            output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Comparison plot saved to {output_path}")
        if show:
            plt.show()
        else:
            plt.close()
        return True
    except Exception as e:
        logger.error(f"Error plotting transcription comparison: {e}")
        return False
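
# Example of the diff marking above (hypothetical sentences):
#   reference:  "the cat sat on the mat"
#   hypothesis: "the cat sit on mat"
#   Ref (Diff): "the cat **sat** on **the** mat"
#   Hyp (Diff): "the cat **sit** on mat"
#   Word Error: 2/6 (33.33%)  -- one substitution (sat/sit), one deletion (the)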


def plot_evaluation_metrics(results_file, output_dir=None, show=True):
    """Plot evaluation metrics from results file."""
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
        metrics = results['metrics']
        details = results['details']
        # Extract duration information
        durations = [item['duration'] for item in details if 'duration' in item]
        if not durations:
            durations = [0] * len(details)  # Default if no durations available
        # Calculate per-sample WER;
        # word_errors() returns (edit_distance, number_of_reference_words)
        from paddlespeech.s2t.utils.error_rate import word_errors
        wers = []
        for item in details:
            ref = item['reference']
            hyp = item['hypothesis']
            word_error_count, word_count = word_errors(ref, hyp)
            wer = word_error_count / max(1, word_count)
            wers.append(wer * 100)  # Convert to percentage
        # Create plots
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        # Overall metrics
        ax = axes[0, 0]
        metric_names = ['WER', 'CER']
        metric_values = [metrics['wer'] * 100,
                         metrics['cer'] * 100]  # Convert to percentage
        ax.bar(metric_names, metric_values, color=['#1f77b4', '#ff7f0e'])
        ax.set_title('Overall Error Metrics')
        ax.set_ylabel('Error Rate (%)')
        ax.grid(axis='y', alpha=0.3)
        for i, v in enumerate(metric_values):
            ax.text(i, v + 0.5, f"{v:.2f}%", ha='center')
        # WER vs Duration scatter plot
        ax = axes[0, 1]
        ax.scatter(durations, wers, alpha=0.6)
        ax.set_title('WER vs Audio Duration')
        ax.set_xlabel('Duration (seconds)')
        ax.set_ylabel('WER (%)')
        ax.grid(alpha=0.3)
        # WER distribution histogram
        ax = axes[1, 0]
        ax.hist(wers, bins=20, color='#2ca02c', alpha=0.7)
        ax.set_title('WER Distribution')
        ax.set_xlabel('WER (%)')
        ax.set_ylabel('Number of Samples')
        ax.grid(alpha=0.3)
        # Sample count by WER range
        ax = axes[1, 1]
        wer_ranges = [0, 5, 10, 20, 30, 50, 100]
        wer_bins = pd.cut(wers, wer_ranges, right=False)
        # pd.value_counts() is deprecated; count via a Series instead
        wer_counts = pd.Series(wer_bins).value_counts().sort_index()
        wer_labels = [
            f"{wer_ranges[i]}-{wer_ranges[i + 1]}%"
            for i in range(len(wer_ranges) - 1)
        ]
        ax.bar(wer_labels, wer_counts, color='#d62728')
        ax.set_title('Sample Count by WER Range')
        ax.set_xlabel('WER Range')
        ax.set_ylabel('Number of Samples')
        ax.tick_params(axis='x', rotation=45)
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        ax.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        if output_dir:
            output_path = os.path.join(output_dir, "evaluation_metrics.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Metrics plots saved to {output_path}")
        if show:
            plt.show()
        else:
            plt.close()
        return True
    except Exception as e:
        logger.error(f"Error plotting evaluation metrics: {e}")
        return False


def visualize_results(results_file, output_dir=None, audio_dir=None,
                      num_samples=5, show=True):
    """Visualize evaluation results."""
    try:
        # Create output directory if needed
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # Load results
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
        # Plot overall metrics
        plot_evaluation_metrics(results_file, output_dir, show)
        # Plot individual examples
        if num_samples > 0:
            details = results['details']
            samples = details[:num_samples]
            for sample in samples:
                audio_path = sample['audio']
                reference = sample['reference']
                hypothesis = sample['hypothesis']
                # If audio directory is provided, use audio files from there
                if audio_dir:
                    base_name = os.path.basename(audio_path)
                    audio_path = os.path.join(audio_dir, base_name)
                # Plot waveform if audio file exists
                if os.path.exists(audio_path):
                    plot_waveform(audio_path, output_dir, show)
                # Plot transcription comparison
                plot_transcription_comparison(reference, hypothesis,
                                              audio_path, output_dir, show)
        return True
    except Exception as e:
        logger.error(f"Error visualizing results: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(
        description="Visualize Whisper evaluation results")
    parser.add_argument("--results_file", type=str, required=True,
                        help="Path to evaluation results JSON file")
    parser.add_argument("--output_dir", type=str, default="./visualizations",
                        help="Directory to save visualization outputs")
    parser.add_argument("--audio_dir", type=str, default=None,
                        help="Directory containing audio files (optional)")
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of individual samples to visualize")
    parser.add_argument("--show", action="store_true",
                        help="Show plots interactively")
    args = parser.parse_args()
    visualize_results(
        results_file=args.results_file,
        output_dir=args.output_dir,
        audio_dir=args.audio_dir,
        num_samples=args.num_samples,
        show=args.show)


if __name__ == "__main__":
    main()
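
# Usage sketch (file paths here are hypothetical):
#
#   python visualize.py \
#       --results_file exp/whisper/test_results.json \
#       --output_dir ./visualizations \
#       --audio_dir data/commonvoice/clips \
#       --num_samples 5
#
# This writes evaluation_metrics.png plus per-sample *_waveform.png and
# *_comparison.png files into --output_dir; add --show to also display the
# figures interactively.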