# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
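
"""Visualize Whisper evaluation results.

Given a results JSON file, this script plots overall WER/CER bars, WER vs.
audio duration, a WER histogram, per-WER-range sample counts, and, for a few
individual samples, waveforms plus reference/hypothesis comparison tables.

The input schema sketched below is inferred from how this script reads the
file ("duration" is optional; all values are illustrative):

    {
        "metrics": {"wer": 0.12, "cer": 0.05},
        "details": [
            {"audio": "path/to/utt_0001.wav",
             "reference": "the cat sat",
             "hypothesis": "the bat sat",
             "duration": 3.2}
        ]
    }

Example invocation (the script name here is illustrative):

    python visualize_results.py --results_file results.json \
        --output_dir ./visualizations --num_samples 5
"""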
import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
from matplotlib.ticker import MaxNLocator

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


def plot_waveform(audio_path, output_dir=None, show=True):
    """Plot waveform of audio file."""
    try:
        audio, sr = sf.read(audio_path)
        if audio.ndim > 1:
            audio = audio[:, 0]  # Take first channel if stereo

        duration = len(audio) / sr
        time_axis = np.linspace(0, duration, len(audio))

        plt.figure(figsize=(12, 4))
        plt.plot(time_axis, audio, color='#1f77b4')
        plt.title(f"Waveform: {os.path.basename(audio_path)}")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Amplitude")
        plt.grid(alpha=0.3)

        if output_dir:
            output_path = os.path.join(
                output_dir,
                f"{os.path.splitext(os.path.basename(audio_path))[0]}_waveform.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Waveform plot saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting waveform: {e}")
        return False


def plot_transcription_comparison(reference, hypothesis, audio_path=None, output_dir=None, show=True):
    """Plot comparison between reference and hypothesis transcriptions."""
    try:
        fig, ax = plt.subplots(figsize=(12, 6))

        # Title with audio file name if provided
        title = "Reference vs. Hypothesis"
        if audio_path:
            title += f": {os.path.basename(audio_path)}"
        plt.title(title)

        # Create comparison table
        rows = []
        rows.append(["Reference", reference])
        rows.append(["Hypothesis", hypothesis])

        # Calculate word-level differences
        ref_words = reference.strip().split()
        hyp_words = hypothesis.strip().split()

        from difflib import SequenceMatcher
        matcher = SequenceMatcher(None, ref_words, hyp_words)
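
        # Illustrative note (not from the original source): get_opcodes()
        # yields (tag, i1, i2, j1, j2) tuples aligning ref_words[i1:i2]
        # with hyp_words[j1:j2]. For example:
        #   ref_words = ['the', 'cat', 'sat']
        #   hyp_words = ['the', 'bat', 'sat', 'down']
        #   opcodes:  ('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2),
        #             ('equal', 2, 3, 2, 3), ('insert', 3, 3, 3, 4)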

        # Generate color-coded text for differences
        ref_colored = []
        hyp_colored = []

        for op, i1, i2, j1, j2 in matcher.get_opcodes():
            if op == 'equal':
                ref_colored.extend(ref_words[i1:i2])
                hyp_colored.extend(hyp_words[j1:j2])
            elif op == 'replace':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
            elif op == 'delete':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
            elif op == 'insert':
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])

        rows.append(["Ref (Diff)", " ".join(ref_colored)])
        rows.append(["Hyp (Diff)", " ".join(hyp_colored)])

        # Calculate error metrics from the alignment: substitutions,
        # deletions and insertions all count as word errors. (A plain
        # positional zip of the two word lists would miscount as soon as an
        # insertion or deletion shifts the alignment.)
        word_error = sum(
            max(i2 - i1, j2 - j1)
            for op, i1, i2, j1, j2 in matcher.get_opcodes()
            if op != 'equal')
        if len(ref_words) > 0:
            word_error_rate = word_error / len(ref_words)
        else:
            word_error_rate = 0

        rows.append(["Word Error", f"{word_error}/{len(ref_words)} ({word_error_rate:.2%})"])

        # Create table
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=rows, colWidths=[0.15, 0.85], loc='center', cellLoc='left')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)

        # Highlight difference rows with light blue background
        for i in [2, 3]:
            for j in range(2):
                table[i, j].set_facecolor('#E6F3FF')

        if output_dir:
            base_name = os.path.splitext(os.path.basename(audio_path))[0] if audio_path else "comparison"
            output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Comparison plot saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting transcription comparison: {e}")
        return False


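# Illustrative usage of plot_transcription_comparison (not from the original
# source):
#   plot_transcription_comparison("the cat sat", "the bat sat down",
#                                 output_dir="./visualizations", show=False)
# renders the table, wraps "cat"/"bat" and the inserted "down" in literal
# **...** markers, and reports a word error count of 2/3.

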
def plot_evaluation_metrics(results_file, output_dir=None, show=True):
    """Plot evaluation metrics from results file."""
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)

        metrics = results['metrics']
        details = results['details']

        # Extract duration information
        durations = [item['duration'] for item in details if 'duration' in item]
        if not durations:
            durations = [0] * len(details)  # Default if no durations available

        # Calculate per-sample WER
        from paddlespeech.s2t.utils.error_rate import word_errors
        wers = []
        for item in details:
            ref = item['reference']
            hyp = item['hypothesis']
            word_error_count, word_count = word_errors(ref, hyp)
            wer = word_error_count / max(1, word_count)
            wers.append(wer * 100)  # Convert to percentage
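
        # Illustrative note (assumes PaddleSpeech's error_rate semantics):
        # word_errors() returns (edit distance in words, reference word
        # count), e.g. word_errors("the cat sat", "the bat sat down")
        # -> (2.0, 3), so each entry of `wers` is a per-sample WER percentage.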

        # Create plots
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Overall metrics
        ax = axes[0, 0]
        metric_names = ['WER', 'CER']
        metric_values = [metrics['wer'] * 100, metrics['cer'] * 100]  # Convert to percentage
        ax.bar(metric_names, metric_values, color=['#1f77b4', '#ff7f0e'])
        ax.set_title('Overall Error Metrics')
        ax.set_ylabel('Error Rate (%)')
        ax.grid(axis='y', alpha=0.3)
        for i, v in enumerate(metric_values):
            ax.text(i, v + 0.5, f"{v:.2f}%", ha='center')

        # WER vs Duration scatter plot
        ax = axes[0, 1]
        ax.scatter(durations, wers, alpha=0.6)
        ax.set_title('WER vs Audio Duration')
        ax.set_xlabel('Duration (seconds)')
        ax.set_ylabel('WER (%)')
        ax.grid(alpha=0.3)

        # WER distribution histogram
        ax = axes[1, 0]
        ax.hist(wers, bins=20, color='#2ca02c', alpha=0.7)
        ax.set_title('WER Distribution')
        ax.set_xlabel('WER (%)')
        ax.set_ylabel('Number of Samples')
        ax.grid(alpha=0.3)

        # Sample count by WER range
        ax = axes[1, 1]
        wer_ranges = [0, 5, 10, 20, 30, 50, 100]
        wer_bins = pd.cut(wers, wer_ranges, right=False)
        # pd.value_counts() is deprecated; count via a Series instead.
        wer_counts = pd.Series(wer_bins).value_counts().sort_index()

        wer_labels = [f"{wer_ranges[i]}-{wer_ranges[i + 1]}%" for i in range(len(wer_ranges) - 1)]
        ax.bar(wer_labels, wer_counts, color='#d62728')
        ax.set_title('Sample Count by WER Range')
        ax.set_xlabel('WER Range')
        ax.set_ylabel('Number of Samples')
        ax.tick_params(axis='x', rotation=45)
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        ax.grid(axis='y', alpha=0.3)

        plt.tight_layout()

        if output_dir:
            output_path = os.path.join(output_dir, "evaluation_metrics.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Metrics plots saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting evaluation metrics: {e}")
        return False


def visualize_results(results_file, output_dir=None, audio_dir=None, num_samples=5, show=True):
    """Visualize evaluation results."""
    try:
        # Create output directory if needed
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Load results
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)

        # Plot overall metrics
        plot_evaluation_metrics(results_file, output_dir, show)

        # Plot individual examples
        if num_samples > 0:
            details = results['details']
            samples = details[:num_samples]

            for sample in samples:
                audio_path = sample['audio']
                reference = sample['reference']
                hypothesis = sample['hypothesis']

                # If audio directory is provided, use audio files from there
                if audio_dir:
                    base_name = os.path.basename(audio_path)
                    audio_path = os.path.join(audio_dir, base_name)

                # Plot waveform if audio file exists
                if os.path.exists(audio_path):
                    plot_waveform(audio_path, output_dir, show)

                # Plot transcription comparison
                plot_transcription_comparison(reference, hypothesis, audio_path, output_dir, show)

        return True
    except Exception as e:
        logger.error(f"Error visualizing results: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Visualize Whisper evaluation results")
    parser.add_argument("--results_file", type=str, required=True, help="Path to evaluation results JSON file")
    parser.add_argument("--output_dir", type=str, default="./visualizations", help="Directory to save visualization outputs")
    parser.add_argument("--audio_dir", type=str, default=None, help="Directory containing audio files (optional)")
    parser.add_argument("--num_samples", type=int, default=5, help="Number of individual samples to visualize")
    parser.add_argument("--show", action="store_true", help="Show plots interactively")

    args = parser.parse_args()

    visualize_results(
        results_file=args.results_file,
        output_dir=args.output_dir,
        audio_dir=args.audio_dir,
        num_samples=args.num_samples,
        show=args.show
    )


if __name__ == "__main__":
    main()