
@@ -0,0 +1,238 @@
# Whisper Fine-tuning on Common Voice
This example demonstrates how to fine-tune Whisper models on the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) dataset using PaddleSpeech.
## Overview
Whisper is a state-of-the-art speech recognition model from OpenAI. This implementation allows you to fine-tune Whisper models on new datasets to improve performance for specific languages, domains, or accents.
## Features
- Complete fine-tuning pipeline for Whisper models on custom datasets
- Flexible configuration via YAML files
- Support for all Whisper model sizes (tiny, base, small, medium, large, large-v2, large-v3)
- Data preparation tools for Common Voice and custom datasets
- Distributed training support with mixed precision
- Gradient accumulation for large batch training
- Learning rate scheduling and optimization techniques
- Evaluation tools with WER/CER metrics
- Command-line inference with both fine-tuned and original Whisper models
- Model export utilities for deployment
- Visualization tools for performance analysis
## Installation
Ensure you have PaddleSpeech installed with all dependencies:
```bash
git clone https://github.com/PaddlePaddle/PaddleSpeech.git
cd PaddleSpeech
pip install -e .
pip install datasets soundfile librosa matplotlib pandas jiwer
```
## Data
We use the Common Voice dataset (version 11.0) available on Hugging Face:
https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0
Other datasets compatible with the pipeline include LibriSpeech, AISHELL, and any dataset that can be converted to the manifest format (see below).
## Unified Command-Line Interface
This example includes a unified CLI for all operations:
```bash
python whisper_cli.py COMMAND [OPTIONS]
```
Available commands:
- `prepare`: Prepare dataset for fine-tuning
- `train`: Fine-tune Whisper model
- `evaluate`: Evaluate model performance
- `infer`: Run inference with fine-tuned or original model
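A typical end-to-end run chains these commands (paths are illustrative):
```bash
# Prepare English Common Voice data, fine-tune, evaluate, then transcribe a file
python whisper_cli.py prepare --language en --output_dir ./data
python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources
python whisper_cli.py evaluate --manifest ./data/test_manifest.json \
    --checkpoint ./exp/whisper_fine_tune/final --output_dir ./eval_results
python whisper_cli.py infer --audio_file path/to/audio.wav --checkpoint ./exp/whisper_fine_tune/final
```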
## Data Preparation
To download and prepare the Common Voice dataset:
```bash
python whisper_cli.py prepare --language en --output_dir ./data
```
Options:
- `--language`: Target language code (default: en)
- `--output_dir`: Directory to save preprocessed data (default: ./data)
- `--cache_dir`: Cache directory for HuggingFace datasets
- `--val_size`: Validation set size ratio (default: 0.03)
- `--test_size`: Test set size ratio (default: 0.03)
- `--min_duration`: Minimum audio duration in seconds (default: 0.5)
- `--max_duration`: Maximum audio duration in seconds (default: 30.0)
Manifest format (one JSON object per line):
```json
{"audio": "path/to/audio.wav", "text": "transcription", "duration": 3.45}
```
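As a quick sanity check before training, you can verify that every manifest line parses and that its duration falls inside the configured range. A minimal sketch, assuming the default 0.5–30.0 s bounds and the default manifest path:
```python
import json

# Check that every line is valid JSON with the expected keys and an in-range duration
with open("data/train_manifest.json", encoding="utf-8") as f:
    for line_no, line in enumerate(f, 1):
        item = json.loads(line)
        assert {"audio", "text", "duration"} <= item.keys(), f"missing keys on line {line_no}"
        assert 0.5 <= item["duration"] <= 30.0, f"duration out of range on line {line_no}"
```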
## Configuration
Fine-tuning parameters are specified in YAML config files. See `conf/whisper_base.yaml` for a detailed example.
Key configuration sections:
- **model**: Model size, checkpoint path, freeze options
- **data**: Dataset paths, languages, tasks
- **training**: Batch size, learning rate, optimizer settings
- **distributed**: Distributed training options
- **output**: Save paths, logging options
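The training entry point loads this file with `yaml.safe_load`, so you can also generate config variants programmatically. A small sketch (the output filename is illustrative):
```python
import yaml

# Load the base config, lower the batch size for a smaller GPU, and save a variant
with open("conf/whisper_base.yaml") as f:
    config = yaml.safe_load(f)
config["training"]["batch_size"] = 8
config["training"]["accum_grad"] = 2  # keep the effective batch size at 16
with open("conf/whisper_base_small_gpu.yaml", "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)
```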
## Training
To fine-tune the Whisper model:
```bash
python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources
```
For distributed training:
```bash
python -m paddle.distributed.launch --gpus "0,1,2,3" whisper_cli.py train --config conf/whisper_base.yaml --distributed True
```
Options:
- `--config`: Path to configuration YAML file
- `--resource_path`: Path to resources directory containing model assets
- `--device`: Device to use (cpu, gpu, xpu)
- `--seed`: Random seed
- `--checkpoint_path`: Path to resume training from checkpoint
- `--distributed`: Enable distributed training
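For example, to resume an interrupted run from a saved epoch checkpoint (path illustrative):
```bash
python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources \
    --checkpoint_path ./exp/whisper_fine_tune/epoch_5
```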
## Evaluation
Evaluate model performance on a test set:
```bash
python whisper_cli.py evaluate --manifest ./data/test_manifest.json --checkpoint ./exp/whisper_fine_tune/epoch_10 --output_dir ./eval_results
```
Options:
- `--manifest`: Path to test manifest file
- `--checkpoint`: Path to model checkpoint
- `--model_size`: Model size if using original Whisper
- `--language`: Language code
- `--output_dir`: Directory to save evaluation results
- `--max_samples`: Maximum number of samples to evaluate
## Inference
For transcribing audio with a fine-tuned model:
```bash
python whisper_cli.py infer --audio_file path/to/audio.wav --checkpoint ./exp/whisper_fine_tune/final
```
For batch processing a directory:
```bash
python whisper_cli.py infer --audio_dir path/to/audio/folder --output_dir ./transcriptions --checkpoint ./exp/whisper_fine_tune/final
```
For inference with the original Whisper models:
```bash
python whisper_cli.py infer --audio_file path/to/audio.wav --use_original --model_size large-v3 --resource_path ./resources
```
Options:
- `--audio_file`: Path to single audio file
- `--audio_dir`: Path to directory with audio files
- `--checkpoint`: Path to fine-tuned checkpoint
- `--use_original`: Use original Whisper model
- `--model_size`: Model size (tiny, base, small, medium, large, large-v2, large-v3)
- `--language`: Language code (or "auto" for detection)
- `--task`: Task type (transcribe or translate)
- `--beam_size`: Beam size for beam search
- `--temperature`: Temperature for sampling
- `--without_timestamps`: Don't include timestamps
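For example, to translate French speech to English with the original medium model and a wider beam (values illustrative):
```bash
python whisper_cli.py infer --audio_file path/to/audio.wav --use_original --model_size medium \
    --resource_path ./resources --language fr --task translate --beam_size 10
```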
## Visualization
Visualize evaluation results:
```bash
python visualize.py --results_file ./eval_results/evaluation_results.json --output_dir ./visualizations
```
Options:
- `--results_file`: Path to evaluation results JSON file
- `--output_dir`: Directory to save visualizations
- `--audio_dir`: Directory with audio files (optional)
- `--num_samples`: Number of individual samples to visualize
- `--show`: Show plots interactively
## Model Export
Export fine-tuned model to inference format:
```bash
python export_model.py --checkpoint ./exp/whisper_fine_tune/final --output_path ./exported_model --model_size base
```
Options:
- `--checkpoint`: Path to model checkpoint
- `--output_path`: Path to save exported model
- `--model_size`: Model size
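The export step writes separate static-graph encoder and decoder programs plus a `*_info.json` file with the model dimensions. A minimal sketch of loading them back for deployment, assuming the `./exported_model` prefix used above and the base model's 80 mel bins and 3000-frame encoder input:
```python
import paddle

# Load the exported static-graph programs (paths follow the "<output_path>_encoder/_decoder" naming)
encoder = paddle.jit.load("./exported_model_encoder")
decoder = paddle.jit.load("./exported_model_decoder")
encoder.eval()
decoder.eval()

# The encoder expects a batch of log-mel spectrograms shaped [batch, 80, 3000]
mel = paddle.randn([1, 80, 3000], dtype="float32")
audio_features = encoder(mel)
print(audio_features.shape)
```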
## Advanced Usage
### Freezing Encoder
To freeze the encoder and only fine-tune the decoder, set the following in your config file:
```yaml
model:
  freeze_encoder: true
```
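Internally this just stops gradients on the encoder parameters; the training script does the equivalent of the following, where `model` is the loaded `Whisper` instance:
```python
# Freeze the audio encoder; only decoder weights keep receiving gradients
for param in model.encoder.parameters():
    param.stop_gradient = True
```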
### Gradient Accumulation
To train with a larger effective batch size than GPU memory allows, use gradient accumulation:
```yaml
training:
  accum_grad: 8  # Accumulate gradients over 8 batches
```
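The training loop implements this by scaling the loss and stepping the optimizer only every `accum_grad` batches, roughly as follows (`train_dataloader`, `compute_loss`, and `optimizer` stand in for the objects built by the trainer):
```python
# Sketch of the accumulation pattern used in the training loop
accum_grad = 8
for batch_idx, batch in enumerate(train_dataloader):
    loss = compute_loss(batch)      # hypothetical helper returning the CE loss for one batch
    (loss / accum_grad).backward()  # scale so gradients average over the accumulation window
    if (batch_idx + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.clear_grad()
```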
### Mixed Precision
Enable mixed precision training for faster computation:
```yaml
training:
  amp: true  # Enable automatic mixed precision
```
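In Paddle, automatic mixed precision is typically applied by wrapping the forward/loss computation in `paddle.amp.auto_cast` and scaling the loss with a `GradScaler`. A generic, self-contained sketch (the toy model and optimizer stand in for the Whisper model and the trainer's AdamW):
```python
import paddle

# Toy model/optimizer so the sketch runs on its own
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-5, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 10])
with paddle.amp.auto_cast(level="O1", dtype="float16"):
    loss = model(x).mean()          # stands in for the cross-entropy loss on a batch
scaled = scaler.scale(loss)         # scale the loss to avoid float16 underflow
scaled.backward()
scaler.step(optimizer)              # unscales gradients and runs optimizer.step()
scaler.update()
optimizer.clear_grad()
```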
### Custom Datasets
To use custom datasets, prepare manifest files with one JSON object per line. Include a `duration` field (in seconds): the training loader filters samples by duration, and entries without it are dropped:
```json
{"audio": "/absolute/path/to/audio.wav", "text": "transcription text", "duration": 3.45}
```
```
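A minimal sketch of building such a manifest from a list of (audio path, transcript) pairs, computing durations with `soundfile` (the input list is hypothetical):
```python
import json
import soundfile as sf

# Hypothetical (audio, transcript) pairs for a custom corpus
pairs = [
    ("/abs/path/utt0001.wav", "first transcription"),
    ("/abs/path/utt0002.wav", "second transcription"),
]
with open("data/train_manifest.json", "w", encoding="utf-8") as f:
    for audio_path, text in pairs:
        info = sf.info(audio_path)  # read frame count / sample rate without loading the audio
        item = {"audio": audio_path, "text": text,
                "duration": info.frames / info.samplerate}
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
```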
Then specify the manifest paths in your config file:
```yaml
data:
  train_manifest: path/to/train_manifest.json
  dev_manifest: path/to/dev_manifest.json
  test_manifest: path/to/test_manifest.json
```
## Reference
- [OpenAI Whisper](https://github.com/openai/whisper)
- [Fine-tuning Whisper](https://huggingface.co/blog/fine-tune-whisper)
- [Paper: Robust Speech Recognition via Large-Scale Weak Supervision](https://arxiv.org/pdf/2212.04356)
- [Common Voice Dataset](https://commonvoice.mozilla.org/)
- [PaddleSpeech Documentation](https://github.com/PaddlePaddle/PaddleSpeech)

@@ -0,0 +1,70 @@
# Configuration for fine-tuning Whisper base model
# Data settings
data:
  train_manifest: "data/train_manifest.json"
  dev_manifest: "data/dev_manifest.json"
  test_manifest: "data/test_manifest.json"
  target_language: "en"    # Language code for fine-tuning
  max_duration: 30.0       # Maximum audio duration in seconds
  min_duration: 0.5        # Minimum audio duration in seconds

# Model settings
model:
  name: "whisper"
  size: "base"             # Options: tiny, base, small, medium, large, large-v2, large-v3
  checkpoint: null         # Path to pre-trained checkpoint, null for default
  freeze_encoder: false    # Whether to freeze the encoder during fine-tuning
  use_fp16: true           # Whether to use half precision

# Training settings
training:
  max_epoch: 20
  save_epoch: 1
  log_interval: 100
  batch_size: 16
  num_workers: 4
  accum_grad: 1            # Gradient accumulation steps

  # Optimizer settings
  optimizer: "adamw"
  learning_rate: 1.0e-5    # written as 1.0e-5 so YAML parses it as a float
  weight_decay: 0.01
  scheduler: "cosine"
  warmup_ratio: 0.03
  max_grad_norm: 1.0

  # Regularization
  dropout: 0.1
  label_smoothing: 0.1

  # Mixed precision training
  amp_level: "O1"
  amp_dtype: "float16"

# Distributed training
distributed:
  use_fleet: true
  strategy: "standard"
  find_unused_parameters: false

# Output settings
output:
  checkpoint_dir: "exp/whisper_fine_tune"
  save_checkpoint: true
  save_interval: 1
  keep_checkpoint_max: 5

# Evaluation settings
eval:
  eval_batch_size: 16
  metrics: ["wer", "cer"]

# Inference settings
inference:
  beam_size: 5
  min_tokens: 0
  max_tokens: 448
  temperature: 0.0
  language: null           # Set to target language code or null to auto-detect
  without_timestamps: true

@@ -0,0 +1,264 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import paddle
import yaml
from paddlespeech.s2t.models.whisper.whisper import (MODEL_DIMENSIONS, Whisper,
DecodingOptions, log_mel_spectrogram,
transcribe)
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_rank, str2bool
from paddlespeech.metrics.wer import word_errors, char_errors
logger = Log(__name__).getlog()
def compute_metrics(references, hypotheses, language="en"):
"""Compute WER and CER metrics."""
total_words = 0
total_chars = 0
total_word_errors = 0
total_char_errors = 0
for ref, hyp in zip(references, hypotheses):
ref = ref.strip()
hyp = hyp.strip()
word_error_count, word_count = word_errors(ref, hyp, language)
char_error_count, char_count = char_errors(ref, hyp, language)
total_words += word_count
total_chars += char_count
total_word_errors += word_error_count
total_char_errors += char_error_count
wer = float(total_word_errors) / max(1, total_words)
cer = float(total_char_errors) / max(1, total_chars)
return {
"wer": wer,
"cer": cer,
"word_errors": total_word_errors,
"word_count": total_words,
"char_errors": total_char_errors,
"char_count": total_chars,
}
def load_model(model_size="base", checkpoint_path=None, resource_path=None):
"""Load Whisper model from checkpoint or pretrained weights."""
model_dims = MODEL_DIMENSIONS[model_size]
model = Whisper(model_dims)
if checkpoint_path:
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
elif resource_path:
model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams")
if os.path.exists(model_path):
logger.info(f"Loading pretrained model from: {model_path}")
state_dict = paddle.load(model_path)
model.set_state_dict(state_dict)
else:
logger.error(f"Pretrained model not found at {model_path}")
raise FileNotFoundError(f"Model file not found: {model_path}")
else:
logger.error("Either checkpoint_path or resource_path must be provided")
raise ValueError("Either checkpoint_path or resource_path must be provided")
return model
def evaluate_manifest(
model,
manifest_path,
resource_path,
language=None,
task="transcribe",
temperature=0.0,
beam_size=5,
patience=1.0,
batch_size=16,
without_timestamps=True,
fp16=False,
verbose=True,
output_dir=None,
max_samples=None
):
"""Evaluate Whisper model on a manifest file."""
# Load manifest
with open(manifest_path, 'r', encoding='utf8') as f:
manifest_lines = f.readlines()
# Limit samples if requested
if max_samples:
manifest_lines = manifest_lines[:max_samples]
references = []
hypotheses = []
audio_paths = []
durations = []
# Process each item
start_time = time.time()
model.eval()
for i, line in enumerate(manifest_lines):
if i % 10 == 0:
logger.info(f"Processing item {i+1}/{len(manifest_lines)}")
item = json.loads(line.strip())
audio_path = item["audio"]
reference_text = item["text"]
# Get duration if available
duration = item.get("duration", 0)
durations.append(duration)
audio_paths.append(audio_path)
references.append(reference_text)
# Process audio
try:
mel = log_mel_spectrogram(audio_path, resource_path=resource_path)
# Setup decoding options
decode_options = DecodingOptions(
language=language,
task=task,
temperature=temperature,
beam_size=beam_size,
patience=patience,
without_timestamps=without_timestamps,
fp16=fp16,
)
# Run transcription
with paddle.no_grad():
result = transcribe(
model=model,
mel=mel,
resource_path=resource_path,
verbose=False,
**decode_options.__dict__,
)
hypotheses.append(result["text"])
except Exception as e:
logger.error(f"Error processing {audio_path}: {str(e)}")
hypotheses.append("")
# Compute metrics
elapsed = time.time() - start_time
logger.info(f"Processed {len(manifest_lines)} examples in {elapsed:.2f} seconds")
metrics = compute_metrics(references, hypotheses, language=language if language else "en")
logger.info(f"WER: {metrics['wer']*100:.2f}%, CER: {metrics['cer']*100:.2f}%")
# Save results if output directory provided
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Save detailed results
results = {
"metrics": metrics,
"details": [
{
"audio": audio_path,
"reference": reference,
"hypothesis": hypothesis,
"duration": duration
}
for audio_path, reference, hypothesis, duration in zip(
audio_paths, references, hypotheses, durations
)
]
}
with open(output_dir / "evaluation_results.json", "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
logger.info(f"Detailed evaluation results saved to {output_dir}")
return metrics
def main():
parser = argparse.ArgumentParser(description="Evaluate Whisper ASR Model")
parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file")
parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results")
parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
parser.add_argument("--resource_path", type=str, default="./resources",
help="Path to resources directory")
# Model options
parser.add_argument("--model_size", type=str, default="base",
choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
help="Model size for original Whisper")
# Decoding options
parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
help="Task: transcribe or translate to English")
parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor")
parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16")
parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display verbose logs")
parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")
args = parser.parse_args()
# Load model
model = load_model(
model_size=args.model_size,
checkpoint_path=args.checkpoint,
resource_path=args.resource_path
)
# Evaluate
evaluate_manifest(
model=model,
manifest_path=args.manifest,
resource_path=args.resource_path,
language=args.language,
task=args.task,
temperature=args.temperature,
beam_size=args.beam_size,
patience=args.patience,
batch_size=args.batch_size,
without_timestamps=args.without_timestamps,
fp16=args.fp16,
verbose=args.verbose,
output_dir=args.output_dir,
max_samples=args.max_samples
)
if __name__ == "__main__":
main()

@@ -0,0 +1,135 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import paddle
from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
def export_encoder(model, save_path, input_shape=(1, 80, 3000)):
"""Export encoder part of Whisper to inference model."""
model.eval()
# Create save directory if not exists
save_dir = os.path.dirname(save_path)
os.makedirs(save_dir, exist_ok=True)
# Define input spec
mel_spec = paddle.static.InputSpec(shape=input_shape, dtype='float32', name='mel')
# Export encoder model
encoder_path = f"{save_path}_encoder"
paddle.jit.save(
layer=model.encoder,
path=encoder_path,
input_spec=[mel_spec]
)
logger.info(f"Encoder model exported to {encoder_path}")
return encoder_path
def export_decoder(model, save_path, input_shape_tokens=(1, 448), input_shape_features=(1, 1500, 768)):
"""Export decoder part of Whisper to inference model."""
model.eval()
# Create save directory if not exists
save_dir = os.path.dirname(save_path)
os.makedirs(save_dir, exist_ok=True)
# Define input spec
token_spec = paddle.static.InputSpec(shape=input_shape_tokens, dtype='int64', name='tokens')
audio_features_spec = paddle.static.InputSpec(shape=input_shape_features, dtype='float32', name='audio_features')
# Create a wrapper to match the exact API of the decoder
class DecoderWrapper(paddle.nn.Layer):
def __init__(self, decoder):
super().__init__()
self.decoder = decoder
def forward(self, tokens, audio_features):
return self.decoder(tokens, audio_features)
wrapper = DecoderWrapper(model.decoder)
# Export decoder model
decoder_path = f"{save_path}_decoder"
paddle.jit.save(
layer=wrapper,
path=decoder_path,
input_spec=[token_spec, audio_features_spec]
)
logger.info(f"Decoder model exported to {decoder_path}")
return decoder_path
def export_whisper(model, save_path):
"""Export full Whisper model to static graph models."""
export_encoder(model, save_path)
export_decoder(model, save_path)
# Export model info
dims = model.dims
model_info = {
"n_mels": dims.n_mels,
"n_vocab": dims.n_vocab,
"n_audio_ctx": dims.n_audio_ctx,
"n_audio_state": dims.n_audio_state,
"n_audio_head": dims.n_audio_head,
"n_audio_layer": dims.n_audio_layer,
"n_text_ctx": dims.n_text_ctx,
"n_text_state": dims.n_text_state,
"n_text_head": dims.n_text_head,
"n_text_layer": dims.n_text_layer
}
# Save model info
import json
with open(f"{save_path}_info.json", "w") as f:
json.dump(model_info, f, indent=4)
logger.info(f"Model info saved to {save_path}_info.json")
def main():
parser = argparse.ArgumentParser(description="Export Whisper model to inference format")
parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path to save exported model")
parser.add_argument("--model_size", type=str, default="base",
choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
help="Model size")
args = parser.parse_args()
# Create model
model_dims = MODEL_DIMENSIONS[args.model_size]
model = Whisper(model_dims)
# Load checkpoint
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict)
# Export model
export_whisper(model, args.output_path)
logger.info(f"Model exported to {args.output_path}")
if __name__ == "__main__":
main()

@@ -0,0 +1,323 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import paddle
import yaml
from paddlespeech.s2t.models.whisper.whisper import (CHUNK_LENGTH, MODEL_DIMENSIONS, N_MELS,
DecodingOptions, Whisper,
log_mel_spectrogram, pad_or_trim,
transcribe)
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_files_by_ext, str2bool
logger = Log(__name__).getlog()
def load_audio(file_path, sample_rate=16000):
"""Load audio file and convert to 16kHz mono if needed."""
try:
import librosa
audio, sr = librosa.load(file_path, sr=sample_rate, mono=True)
return audio
except Exception as e:
logger.error(f"Error loading audio file {file_path}: {str(e)}")
return None
def get_model_path(size, resource_path):
"""Get the model path based on size."""
return os.path.join(resource_path, "whisper", f"whisper-{size}.pdparams")
def load_whisper_model(model_size="base", checkpoint_path=None, resource_path=None):
"""Load Whisper model from checkpoint or pretrained weights."""
model_dims = MODEL_DIMENSIONS[model_size]
model = Whisper(model_dims)
if checkpoint_path:
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
elif resource_path:
model_path = get_model_path(model_size, resource_path)
if os.path.exists(model_path):
logger.info(f"Loading pretrained model from: {model_path}")
state_dict = paddle.load(model_path)
model.set_state_dict(state_dict)
else:
logger.error(f"Pretrained model not found at {model_path}")
raise FileNotFoundError(f"Model file not found: {model_path}")
else:
logger.error("Either checkpoint_path or resource_path must be provided")
raise ValueError("Either checkpoint_path or resource_path must be provided")
return model
def transcribe_file(
model,
audio_file,
resource_path,
language=None,
task="transcribe",
temperature=0.0,
beam_size=5,
best_of=5,
patience=1.0,
length_penalty=1.0,
suppress_tokens="-1",
initial_prompt=None,
condition_on_previous_text=True,
without_timestamps=False,
fp16=False,
verbose=True,
):
"""Transcribe a single audio file."""
# Check if file exists
if not os.path.exists(audio_file):
logger.error(f"Audio file not found: {audio_file}")
return None
# Load and process audio
logger.info(f"Processing audio file: {audio_file}")
try:
mel = log_mel_spectrogram(audio_file, resource_path=resource_path)
except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
return None
# Setup decoding options
decode_options = DecodingOptions(
language=language,
task=task,
temperature=temperature,
beam_size=beam_size,
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
suppress_tokens=suppress_tokens,
prompt=initial_prompt,
without_timestamps=without_timestamps,
fp16=fp16,
)
# Run transcription
logger.info("Running transcription...")
start_time = time.time()
model.eval()
with paddle.no_grad():
result = transcribe(
model=model,
mel=mel,
resource_path=resource_path,
verbose=verbose,
condition_on_previous_text=condition_on_previous_text,
**decode_options.__dict__,
)
elapsed = time.time() - start_time
logger.info(f"Transcription completed in {elapsed:.2f} seconds")
return result
def batch_transcribe(
model,
audio_dir,
output_dir,
resource_path,
extensions=["wav", "mp3", "flac", "m4a", "ogg"],
**transcribe_kwargs
):
"""Transcribe all audio files in a directory."""
# Create output directory
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Get list of audio files
audio_files = []
for ext in extensions:
audio_files.extend(get_files_by_ext(audio_dir, ext))
if not audio_files:
logger.error(f"No audio files found in {audio_dir} with extensions {extensions}")
return
logger.info(f"Found {len(audio_files)} audio files to process")
# Process each file
results = {}
for audio_file in audio_files:
base_name = os.path.basename(audio_file)
file_name = os.path.splitext(base_name)[0]
logger.info(f"Processing file: {base_name}")
result = transcribe_file(model, audio_file, resource_path, **transcribe_kwargs)
if result:
# Save transcription
output_file = output_dir / f"{file_name}.txt"
with open(output_file, "w", encoding="utf-8") as f:
f.write(result["text"].strip())
# Save detailed results
results[base_name] = {
"text": result["text"].strip(),
"segments": result.get("segments", []),
"language": result.get("language", ""),
}
# Save all results as JSON
import json
with open(output_dir / "all_transcripts.json", "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
logger.info(f"All transcriptions saved to {output_dir}")
return results
def main():
parser = argparse.ArgumentParser(description="Whisper ASR Inference")
parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription")
parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files")
parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
parser.add_argument("--resource_path", type=str, default="./resources",
help="Path to resources directory containing original models and assets")
# Model options
parser.add_argument("--use_original", type=str2bool, default=False,
help="Use original Whisper model instead of fine-tuned")
parser.add_argument("--model_size", type=str, default="base",
choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
help="Model size for original Whisper")
# Decoding options
parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
help="Task: transcribe or translate to English")
parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
parser.add_argument("--best_of", type=int, default=5, help="Number of candidates when sampling with temp > 0")
parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor")
parser.add_argument("--length_penalty", type=float, default=1.0, help="Exponential length penalty")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="Comma-separated list of token ids to suppress")
parser.add_argument("--initial_prompt", type=str, default=None, help="Optional text to provide as prompt")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
help="Whether to condition on previous text")
parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps")
parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16")
parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display the text being decoded")
args = parser.parse_args()
# Validate arguments
if not args.audio_file and not args.audio_dir:
parser.error("Either --audio_file or --audio_dir must be specified")
if args.use_original:
# Use original Whisper model
if not args.resource_path:
parser.error("--resource_path must be specified when using original model")
model = load_whisper_model(
model_size=args.model_size,
resource_path=args.resource_path
)
else:
# Use fine-tuned model
if not args.checkpoint:
parser.error("--checkpoint must be specified when not using original model")
# Determine model size from checkpoint directory structure
model_size = "base" # Default
if "tiny" in args.checkpoint:
model_size = "tiny"
elif "small" in args.checkpoint:
model_size = "small"
elif "medium" in args.checkpoint:
model_size = "medium"
elif "large" in args.checkpoint:
model_size = "large"
model = load_whisper_model(
model_size=model_size,
checkpoint_path=args.checkpoint,
resource_path=args.resource_path
)
# Prepare transcription keyword arguments
transcribe_kwargs = {
"language": args.language,
"task": args.task,
"temperature": args.temperature,
"beam_size": args.beam_size,
"best_of": args.best_of,
"patience": args.patience,
"length_penalty": args.length_penalty,
"suppress_tokens": args.suppress_tokens,
"initial_prompt": args.initial_prompt,
"condition_on_previous_text": args.condition_on_previous_text,
"without_timestamps": args.without_timestamps,
"fp16": args.fp16,
"verbose": args.verbose,
}
# Run transcription
if args.audio_file:
# Single file transcription
result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
if result:
print("-" * 40)
print("Transcription:")
print(result["text"])
print("-" * 40)
# Save to output directory if specified
if args.output_dir:
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
file_name = os.path.splitext(os.path.basename(args.audio_file))[0]
with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
f.write(result["text"].strip())
import json
with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Results saved to {output_dir}")
elif args.audio_dir:
# Batch transcription
batch_transcribe(
model,
args.audio_dir,
args.output_dir,
args.resource_path,
**transcribe_kwargs
)
if __name__ == "__main__":
main()

@@ -0,0 +1,197 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Union
import datasets
import numpy as np
import soundfile
import tqdm
import yaml
def prepare_common_voice(language: str,
output_dir: str,
cache_dir: Optional[str] = None,
val_size: float = 0.03,
test_size: float = 0.03,
min_duration: float = 0.5,
max_duration: float = 30.0,
seed: int = 42):
"""
Prepare Mozilla Common Voice dataset for Whisper fine-tuning.
Args:
language: Language code (e.g., "en", "fr", "es")
output_dir: Directory to save preprocessed data
cache_dir: Cache directory for HuggingFace datasets
val_size: Validation set size as a fraction of total data
test_size: Test set size as a fraction of total data
min_duration: Minimum audio duration in seconds
max_duration: Maximum audio duration in seconds
seed: Random seed for reproducibility
"""
print(f"Preparing Common Voice dataset for language: {language}")
# Create output directories
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
(output_dir / "wavs").mkdir(exist_ok=True)
# Load Common Voice dataset
print("Loading Common Voice dataset from HuggingFace...")
try:
common_voice = datasets.load_dataset(
"mozilla-foundation/common_voice_11_0",
language,
cache_dir=cache_dir,
trust_remote_code=True
)
except Exception as e:
print(f"Error loading dataset: {e}")
print("Make sure you have access to the Common Voice dataset on HuggingFace.")
return
# Filter and process data
print("Processing dataset...")
# Function to process and filter examples
def process_example(example):
audio = example['audio']
sample_rate = audio['sampling_rate']
# Calculate duration
duration = len(audio['array']) / sample_rate
# Filter by duration
if duration < min_duration or duration > max_duration:
return None
return {
"path": None, # Will be filled later
"audio": audio,
"text": example['sentence'],
"duration": duration,
}
# Process all splits
all_data = []
for split in ['train', 'validation', 'test']:
if split in common_voice:
split_data = []
for example in tqdm.tqdm(common_voice[split], desc=f"Processing {split} set"):
processed = process_example(example)
if processed:
split_data.append(processed)
all_data.extend(split_data)
# Shuffle and split data
random.seed(seed)
random.shuffle(all_data)
total_size = len(all_data)
val_count = max(1, int(total_size * val_size))
test_count = max(1, int(total_size * test_size))
train_count = total_size - val_count - test_count
train_data = all_data[:train_count]
val_data = all_data[train_count:train_count + val_count]
test_data = all_data[train_count + val_count:]
print(f"Dataset split - Train: {len(train_data)}, Dev: {len(val_data)}, Test: {len(test_data)}")
# Save audio files and create manifest files
def save_manifest(data, name):
manifest = []
for i, item in enumerate(tqdm.tqdm(data, desc=f"Saving {name} files")):
# Generate filename
filename = f"{name}_{i:08d}.wav"
filepath = str(output_dir / "wavs" / filename)
# Save audio file
soundfile.write(
filepath,
item["audio"]["array"],
item["audio"]["sampling_rate"]
)
# Add to manifest
manifest_item = {
"utt": f"{name}_{i:08d}",
"audio": filepath,
"text": item["text"],
"duration": item["duration"]
}
manifest.append(manifest_item)
# Write manifest file
with open(output_dir / f"{name}_manifest.json", "w", encoding="utf-8") as f:
for item in manifest:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
save_manifest(train_data, "train")
save_manifest(val_data, "dev")
save_manifest(test_data, "test")
print(f"Dataset preparation complete. Files saved to {output_dir}")
    # Save dataset statistics
stats = {
"language": language,
"total_examples": total_size,
"train_examples": len(train_data),
"dev_examples": len(val_data),
"test_examples": len(test_data),
"min_duration": min_duration,
"max_duration": max_duration,
}
with open(output_dir / "stats.yaml", "w") as f:
yaml.dump(stats, f)
print("Data preparation complete!")
def main():
parser = argparse.ArgumentParser(description="Prepare Common Voice dataset for Whisper fine-tuning")
parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)")
parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data")
parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets")
parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction")
parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction")
parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds")
parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
args = parser.parse_args()
prepare_common_voice(
language=args.language,
output_dir=args.output_dir,
cache_dir=args.cache_dir,
val_size=args.val_size,
test_size=args.test_size,
min_duration=args.min_duration,
max_duration=args.max_duration,
seed=args.seed
)
if __name__ == "__main__":
main()

@@ -0,0 +1,510 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn as nn
import paddle.nn.functional as F
import yaml
from paddle.io import DataLoader, Dataset
from paddle.optimizer import AdamW
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.whisper import (CHUNK_LENGTH, MODEL_DIMENSIONS, N_MELS, N_SAMPLES, Whisper,
log_mel_spectrogram, pad_or_trim)
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.training.extensions.evaluator import StandardEvaluator
from paddlespeech.s2t.training.extensions.reporter import ObsScope, Reporter
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_rank, str2bool
logger = Log(__name__).getlog()
class WhisperDataset(Dataset):
"""Dataset for Whisper fine-tuning"""
def __init__(
self,
manifest_path: str,
tokenizer,
target_language: str = "en",
max_duration: float = 30.0,
min_duration: float = 0.5,
sample_rate: int = 16000,
resource_path: str = '',
):
"""Initialize the dataset.
Args:
manifest_path: Path to manifest file with audio paths and transcripts
tokenizer: Whisper tokenizer
target_language: Target language code
max_duration: Maximum audio duration
min_duration: Minimum audio duration
sample_rate: Audio sample rate
resource_path: Path to resources directory
"""
super().__init__()
self.tokenizer = tokenizer
self.target_language = target_language
self.sample_rate = sample_rate
self.resource_path = resource_path
# Load manifest
with open(manifest_path, 'r', encoding='utf8') as f:
manifest_lines = f.readlines()
self.data = []
for line in manifest_lines:
try:
item = json.loads(line.strip())
duration = item.get('duration', 0)
if min_duration <= duration <= max_duration:
self.data.append(item)
except Exception as e:
logger.warning(f"Error parsing line in manifest: {e}")
logger.info(f"Loaded {len(self.data)} examples from {manifest_path}")
# Generate special tokens and language tokens
self.special_tokens = {
"sot": self.tokenizer.sot,
"eot": self.tokenizer.eot,
"translate": self.tokenizer.translate,
"transcribe": self.tokenizer.transcribe,
}
# Get language token
self.language_token = self.tokenizer.language_tokens[self.target_language]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# Load audio
audio_path = item["audio"]
try:
# Process audio to mel spectrogram
mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
mel = pad_or_trim(mel, N_SAMPLES)
# Get text
text = item["text"]
# Create prompt tokens
prompt_tokens = [
self.special_tokens["sot"],
self.language_token,
self.special_tokens["transcribe"],
]
# Encode the text
target_tokens = (
prompt_tokens +
self.tokenizer.encode(text) +
[self.special_tokens["eot"]]
)
return {
"mel": mel,
"target_tokens": np.array(target_tokens, dtype=np.int64),
"text": text
}
except Exception as e:
logger.warning(f"Error processing {audio_path}: {e}")
# Return a dummy sample that will be filtered in collate_fn
return None
@staticmethod
def collate_fn(batch):
"""Collate function for DataLoader"""
# Filter None samples
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
# Get maximum sequence length in this batch
max_target_len = max(len(sample["target_tokens"]) for sample in batch)
# Initialize tensors
mel_specs = []
token_ids = []
labels = []
for sample in batch:
target_tokens = sample["target_tokens"]
target_len = len(target_tokens)
# Prepare inputs and labels for causal LM training
# Input tokens are shifted right
input_tokens = np.zeros(max_target_len, dtype=np.int64)
input_tokens[:target_len-1] = target_tokens[:target_len-1] # Exclude EOS
# Labels are shifted left and padded with -100 (ignore index)
label = np.full(max_target_len, -100, dtype=np.int64)
label[:target_len-1] = target_tokens[1:target_len] # Start from first token after SOT
# Add to lists
mel_specs.append(sample["mel"])
token_ids.append(input_tokens)
labels.append(label)
# Convert to tensors
mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)
token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64)
labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64)
return {
"mel": mel_specs,
"tokens": token_ids,
"labels": labels,
}
class WhisperTrainer(Trainer):
"""Trainer for Whisper fine-tuning"""
def __init__(self, config, args):
"""Initialize the trainer.
Args:
config: Training configuration
args: Command-line arguments
"""
super().__init__()
self.config = config
self.args = args
self.resource_path = args.resource_path
# Set random seed
if args.seed is not None:
paddle.seed(args.seed)
np.random.seed(args.seed)
# Initialize distributed training if needed
if args.distributed:
dist.init_parallel_env()
# Build model and optimizer
self.model = self._build_model()
self.optimizer = self._build_optimizer()
# Build tokenizer
self.tokenizer = get_tokenizer(
multilingual=self.model.is_multilingual,
resource_path=self.resource_path,
language=config['data']['target_language'],
task="transcribe"
)
# Initialize checkpoint class
self._init_checkpoint()
def _build_model(self):
"""Build Whisper model"""
config = self.config
model_size = config['model']['size']
# Load model dimensions
model_dims = MODEL_DIMENSIONS[model_size]
# Create the model
model = Whisper(model_dims)
# Load checkpoint if provided
checkpoint_path = config['model']['checkpoint']
if checkpoint_path:
logger.info(f"Loading checkpoint from {checkpoint_path}")
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
# Freeze encoder if needed
if config['model'].get('freeze_encoder', False):
logger.info("Freezing encoder parameters")
for param in model.encoder.parameters():
param.stop_gradient = True
# Handle distributed training
if self.args.distributed:
model = paddle.DataParallel(model)
return model
def _build_optimizer(self):
"""Build optimizer and learning rate scheduler"""
config = self.config
model = self.model
# Set learning rate
learning_rate = config['training']['learning_rate']
# Apply weight decay
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
# Build scheduler
scheduler_name = config['training'].get('scheduler', 'linear')
num_training_steps = self.config['training'].get('max_steps', 100000)
# Calculate warmup steps
warmup_ratio = config['training'].get('warmup_ratio', 0.1)
warmup_steps = int(num_training_steps * warmup_ratio)
if scheduler_name == 'cosine':
lr_scheduler = LinearWarmup(
CosineAnnealingDecay(learning_rate, num_training_steps - warmup_steps),
warmup_steps,
0.0,
learning_rate
)
else: # default to linear
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
paddle.optimizer.lr.PolynomialDecay(
learning_rate=learning_rate,
decay_steps=num_training_steps - warmup_steps,
end_lr=0.0,
power=1.0),
warmup_steps,
0.0,
learning_rate
)
# Create optimizer
optimizer = AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=model.parameters(),
weight_decay=config['training'].get('weight_decay', 0.01),
grad_clip=nn.ClipGradByNorm(config['training'].get('max_grad_norm', 1.0)),
apply_decay_param_fun=lambda x: x in decay_params
)
return optimizer
def _init_checkpoint(self):
"""Initialize checkpoint for saving and loading"""
config = self.config
args = self.args
checkpoint_dir = Path(config['output']['checkpoint_dir'])
checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint = Checkpoint(
checkpoint_dir=checkpoint_dir,
model=self.model,
optimizer=self.optimizer,
infos=dict(),
visualizer=None,
**{"epoch": 0}
)
# Try to load from checkpoint if provided
if args.checkpoint_path:
self.checkpoint.load_parameters(args.checkpoint_path)
def train(self):
"""Run training"""
config = self.config
args = self.args
# Create dataset
train_dataset = WhisperDataset(
manifest_path=config['data']['train_manifest'],
tokenizer=self.tokenizer,
target_language=config['data']['target_language'],
max_duration=config['data']['max_duration'],
min_duration=config['data']['min_duration'],
resource_path=self.resource_path,
)
dev_dataset = WhisperDataset(
manifest_path=config['data']['dev_manifest'],
tokenizer=self.tokenizer,
target_language=config['data']['target_language'],
max_duration=config['data']['max_duration'],
min_duration=config['data']['min_duration'],
resource_path=self.resource_path,
)
# Create data loaders
train_dataloader = DataLoader(
train_dataset,
batch_size=config['training']['batch_size'],
shuffle=True,
num_workers=config['training'].get('num_workers', 4),
collate_fn=WhisperDataset.collate_fn,
drop_last=True,
)
dev_dataloader = DataLoader(
dev_dataset,
batch_size=config['eval']['eval_batch_size'],
shuffle=False,
num_workers=config['training'].get('num_workers', 4),
collate_fn=WhisperDataset.collate_fn,
)
# Setup training
max_epoch = config['training']['max_epoch']
accum_grad = config['training'].get('accum_grad', 1)
log_interval = config['training'].get('log_interval', 100)
save_interval = config['output'].get('save_interval', 1)
# Initialize reporter for logging
reporter = Reporter()
reporter.register(self.model, "model")
# Initialize evaluator
evaluator = StandardEvaluator(self.model)
# Training loop
for epoch in range(max_epoch):
self.model.train()
train_loss = 0
num_batches = 0
start_time = time.time()
for batch_idx, batch in enumerate(train_dataloader):
if batch is None:
continue
mel = batch["mel"]
tokens = batch["tokens"]
labels = batch["labels"]
# Forward pass
audio_features = self.model.embed_audio(mel)
logits = self.model.logits(tokens, audio_features)
# Compute loss
loss = F.cross_entropy(
logits.reshape([-1, logits.shape[-1]]),
labels.reshape([-1]),
ignore_index=-100
)
# Scale loss for gradient accumulation
if accum_grad > 1:
loss = loss / accum_grad
# Backward pass
loss.backward()
# Update parameters every accum_grad steps
if (batch_idx + 1) % accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
# Logging
train_loss += loss.item() * accum_grad
num_batches += 1
if batch_idx % log_interval == 0 and get_rank() == 0:
elapsed_time = time.time() - start_time
logger.info(f"Epoch {epoch}/{max_epoch} | Batch {batch_idx}/{len(train_dataloader)} | "
f"Loss: {loss.item()*accum_grad:.4f} | {elapsed_time:.2f}s elapsed")
# End of epoch
avg_train_loss = train_loss / num_batches if num_batches > 0 else float('inf')
logger.info(f"Epoch {epoch}/{max_epoch} | Average Train Loss: {avg_train_loss:.4f}")
# Evaluation
self.model.eval()
dev_losses = []
with paddle.no_grad():
for batch in dev_dataloader:
if batch is None:
continue
mel = batch["mel"]
tokens = batch["tokens"]
labels = batch["labels"]
# Forward pass
audio_features = self.model.embed_audio(mel)
logits = self.model.logits(tokens, audio_features)
# Compute loss
loss = F.cross_entropy(
logits.reshape([-1, logits.shape[-1]]),
labels.reshape([-1]),
ignore_index=-100
)
dev_losses.append(loss.item())
avg_dev_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf')
logger.info(f"Epoch {epoch}/{max_epoch} | Validation Loss: {avg_dev_loss:.4f}")
# Update checkpoint info
self.checkpoint.infos["epoch"] = epoch
self.checkpoint.infos["train_loss"] = avg_train_loss
self.checkpoint.infos["dev_loss"] = avg_dev_loss
# Save checkpoint
if epoch % save_interval == 0 and get_rank() == 0:
self.checkpoint.save_parameters(tag=f"epoch_{epoch}")
# Save final model
if get_rank() == 0:
self.checkpoint.save_parameters(tag="final")
logger.info(f"Training completed. Final model saved at {self.checkpoint.checkpoint_dir}")
def main():
parser = argparse.ArgumentParser(description="Train Whisper model")
parser.add_argument("--config", type=str, required=True, help="Path to configuration file")
parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")
parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training")
args = parser.parse_args()
# Load configuration
with open(args.config) as f:
config = yaml.safe_load(f)
# Set device
paddle.set_device(args.device)
trainer = WhisperTrainer(config, args)
trainer.train()
if __name__ == "__main__":
    main()

@@ -0,0 +1,295 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
from matplotlib.ticker import MaxNLocator
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
def plot_waveform(audio_path, output_dir=None, show=True):
"""Plot waveform of audio file."""
try:
audio, sr = sf.read(audio_path)
if audio.ndim > 1:
audio = audio[:, 0] # Take first channel if stereo
duration = len(audio) / sr
time_axis = np.linspace(0, duration, len(audio))
plt.figure(figsize=(12, 4))
plt.plot(time_axis, audio, color='#1f77b4')
plt.title(f"Waveform: {os.path.basename(audio_path)}")
plt.xlabel("Time (seconds)")
plt.ylabel("Amplitude")
plt.grid(alpha=0.3)
if output_dir:
output_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(audio_path))[0]}_waveform.png")
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Waveform plot saved to {output_path}")
if show:
plt.show()
else:
plt.close()
return True
except Exception as e:
logger.error(f"Error plotting waveform: {e}")
return False
def plot_transcription_comparison(reference, hypothesis, audio_path=None, output_dir=None, show=True):
"""Plot comparison between reference and hypothesis transcriptions."""
try:
fig, ax = plt.subplots(figsize=(12, 6))
# Title with audio file name if provided
title = "Reference vs. Hypothesis"
if audio_path:
title += f": {os.path.basename(audio_path)}"
plt.title(title)
# Create comparison table
rows = []
rows.append(["Reference", reference])
rows.append(["Hypothesis", hypothesis])
# Calculate word-level differences
ref_words = reference.strip().split()
hyp_words = hypothesis.strip().split()
from difflib import SequenceMatcher
matcher = SequenceMatcher(None, ref_words, hyp_words)
# Generate color-coded text for differences
ref_colored = []
hyp_colored = []
for op, i1, i2, j1, j2 in matcher.get_opcodes():
if op == 'equal':
ref_colored.extend(ref_words[i1:i2])
hyp_colored.extend(hyp_words[j1:j2])
elif op == 'replace':
ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
elif op == 'delete':
ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
elif op == 'insert':
hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
rows.append(["Ref (Diff)", " ".join(ref_colored)])
rows.append(["Hyp (Diff)", " ".join(hyp_colored)])
        # Rough error count: positional word mismatches only (not an alignment-based WER)
word_error = sum(1 for i, j in zip(ref_words, hyp_words) if i != j)
if len(ref_words) > 0:
word_error_rate = word_error / len(ref_words)
else:
word_error_rate = 0
rows.append(["Word Error", f"{word_error}/{len(ref_words)} ({word_error_rate:.2%})"])
# Create table
ax.axis('tight')
ax.axis('off')
table = ax.table(cellText=rows, colWidths=[0.15, 0.85], loc='center', cellLoc='left')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.5)
# Highlight difference rows with light blue background
for i in [2, 3]:
for j in range(2):
cell = table._cells[(i, j)]
cell.set_facecolor('#E6F3FF')
if output_dir:
base_name = os.path.splitext(os.path.basename(audio_path))[0] if audio_path else "comparison"
output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Comparison plot saved to {output_path}")
if show:
plt.show()
else:
plt.close()
return True
except Exception as e:
logger.error(f"Error plotting transcription comparison: {e}")
return False
def plot_evaluation_metrics(results_file, output_dir=None, show=True):
"""Plot evaluation metrics from results file."""
try:
with open(results_file, 'r', encoding='utf-8') as f:
results = json.load(f)
metrics = results['metrics']
details = results['details']
# Extract duration information
durations = [item['duration'] for item in details if 'duration' in item]
if not durations:
durations = [0] * len(details) # Default if no durations available
# Calculate per-sample WER
from paddlespeech.metrics.wer import word_errors
wers = []
for item in details:
ref = item['reference']
hyp = item['hypothesis']
word_error_count, word_count = word_errors(ref, hyp)
wer = word_error_count / max(1, word_count)
wers.append(wer * 100) # Convert to percentage
# Create plots
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Overall metrics
ax = axes[0, 0]
metric_names = ['WER', 'CER']
metric_values = [metrics['wer'] * 100, metrics['cer'] * 100] # Convert to percentage
ax.bar(metric_names, metric_values, color=['#1f77b4', '#ff7f0e'])
ax.set_title('Overall Error Metrics')
ax.set_ylabel('Error Rate (%)')
ax.grid(axis='y', alpha=0.3)
for i, v in enumerate(metric_values):
ax.text(i, v + 0.5, f"{v:.2f}%", ha='center')
# WER vs Duration scatter plot
ax = axes[0, 1]
ax.scatter(durations, wers, alpha=0.6)
ax.set_title('WER vs Audio Duration')
ax.set_xlabel('Duration (seconds)')
ax.set_ylabel('WER (%)')
ax.grid(alpha=0.3)
# WER distribution histogram
ax = axes[1, 0]
ax.hist(wers, bins=20, color='#2ca02c', alpha=0.7)
ax.set_title('WER Distribution')
ax.set_xlabel('WER (%)')
ax.set_ylabel('Number of Samples')
ax.grid(alpha=0.3)
# Sample count by WER range
ax = axes[1, 1]
wer_ranges = [0, 5, 10, 20, 30, 50, 100]
wer_bins = pd.cut(wers, wer_ranges, right=False)
        wer_counts = pd.Series(wer_bins).value_counts().sort_index()
wer_labels = [f"{wer_ranges[i]}-{wer_ranges[i+1]}%" for i in range(len(wer_ranges)-1)]
ax.bar(wer_labels, wer_counts, color='#d62728')
ax.set_title('Sample Count by WER Range')
ax.set_xlabel('WER Range')
ax.set_ylabel('Number of Samples')
ax.tick_params(axis='x', rotation=45)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
if output_dir:
output_path = os.path.join(output_dir, "evaluation_metrics.png")
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Metrics plots saved to {output_path}")
if show:
plt.show()
else:
plt.close()
return True
except Exception as e:
logger.error(f"Error plotting evaluation metrics: {e}")
return False
def visualize_results(results_file, output_dir=None, audio_dir=None, num_samples=5, show=True):
"""Visualize evaluation results."""
try:
# Create output directory if needed
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Load results
with open(results_file, 'r', encoding='utf-8') as f:
results = json.load(f)
# Plot overall metrics
plot_evaluation_metrics(results_file, output_dir, show)
# Plot individual examples
if num_samples > 0:
details = results['details']
samples = details[:num_samples]
for i, sample in enumerate(samples):
audio_path = sample['audio']
reference = sample['reference']
hypothesis = sample['hypothesis']
# If audio directory is provided, use audio files from there
if audio_dir:
base_name = os.path.basename(audio_path)
audio_path = os.path.join(audio_dir, base_name)
# Plot waveform if audio file exists
if os.path.exists(audio_path):
plot_waveform(audio_path, output_dir, show)
# Plot transcription comparison
plot_transcription_comparison(reference, hypothesis, audio_path, output_dir, show)
return True
except Exception as e:
logger.error(f"Error visualizing results: {e}")
return False
def main():
parser = argparse.ArgumentParser(description="Visualize Whisper evaluation results")
parser.add_argument("--results_file", type=str, required=True, help="Path to evaluation results JSON file")
parser.add_argument("--output_dir", type=str, default="./visualizations", help="Directory to save visualization outputs")
parser.add_argument("--audio_dir", type=str, default=None, help="Directory containing audio files (optional)")
parser.add_argument("--num_samples", type=int, default=5, help="Number of individual samples to visualize")
parser.add_argument("--show", action="store_true", help="Show plots interactively")
args = parser.parse_args()
visualize_results(
results_file=args.results_file,
output_dir=args.output_dir,
audio_dir=args.audio_dir,
num_samples=args.num_samples,
show=args.show
)
if __name__ == "__main__":
main()
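
For reference, a typical invocation of the visualization script above looks like the following; the script file name `visualize.py` is assumed here (the diff does not show it), and the paths are placeholders pointing at the JSON produced by the `evaluate` command:

```bash
# Assumed script name; point --results_file at the evaluation results JSON.
python visualize.py \
    --results_file ./results/results.json \
    --output_dir ./visualizations \
    --audio_dir ./data/audio \
    --num_samples 5
```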

@ -0,0 +1,240 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from pathlib import Path
import paddle
import yaml
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import str2bool
logger = Log(__name__).getlog()
def prepare_parser():
"""Prepare argument parser."""
parser = argparse.ArgumentParser(
description="Whisper CLI for fine-tuning, inference, and evaluation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Create subparsers for different commands
subparsers = parser.add_subparsers(dest="command", help="Command to run")
# Data preparation command
data_parser = subparsers.add_parser("prepare", help="Prepare data for fine-tuning")
data_parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)")
data_parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data")
data_parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets")
data_parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction")
data_parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction")
data_parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds")
data_parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds")
data_parser.add_argument("--seed", type=int, default=42, help="Random seed")
# Fine-tuning command
train_parser = subparsers.add_parser("train", help="Fine-tune Whisper model")
train_parser.add_argument("--config", type=str, required=True, help="Path to configuration file")
train_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")
train_parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
train_parser.add_argument("--seed", type=int, default=42, help="Random seed")
train_parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
train_parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training")
# Inference command
infer_parser = subparsers.add_parser("infer", help="Inference with Whisper model")
infer_parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription")
infer_parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files")
infer_parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
infer_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
infer_parser.add_argument("--resource_path", type=str, default="./resources",
help="Path to resources directory containing original models and assets")
# Model options for inference
infer_parser.add_argument("--use_original", type=str2bool, default=False,
help="Use original Whisper model instead of fine-tuned")
infer_parser.add_argument("--model_size", type=str, default="base",
choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
help="Model size for original Whisper")
# Decoding options for inference
infer_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)")
infer_parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
help="Task: transcribe or translate to English")
infer_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
infer_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
infer_parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps")
# Evaluation command
eval_parser = subparsers.add_parser("evaluate", help="Evaluate Whisper model")
eval_parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file")
eval_parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results")
eval_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
eval_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")
# Model options for evaluation
eval_parser.add_argument("--model_size", type=str, default="base",
choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
help="Model size for original Whisper")
# Decoding options for evaluation
eval_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)")
eval_parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
help="Task: transcribe or translate to English")
eval_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
eval_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
eval_parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps")
eval_parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")
return parser
def main():
parser = prepare_parser()
args = parser.parse_args()
if not args.command:
parser.print_help()
sys.exit(1)
# Execute the appropriate command
if args.command == "prepare":
from prepare_data import prepare_common_voice
prepare_common_voice(
language=args.language,
output_dir=args.output_dir,
cache_dir=args.cache_dir,
val_size=args.val_size,
test_size=args.test_size,
min_duration=args.min_duration,
max_duration=args.max_duration,
seed=args.seed
)
elif args.command == "train":
import train
if args.distributed:
import paddle.distributed as dist
dist.init_parallel_env()
# Load configuration
with open(args.config) as f:
config = yaml.safe_load(f)
# Set device
paddle.set_device(args.device)
trainer = train.WhisperTrainer(config, args)
trainer.train()
elif args.command == "infer":
from infer import load_whisper_model, transcribe_file, batch_transcribe
# Validate arguments
if not args.audio_file and not args.audio_dir:
logger.error("Either --audio_file or --audio_dir must be specified")
sys.exit(1)
if args.use_original:
# Use original Whisper model
if not args.resource_path:
logger.error("--resource_path must be specified when using original model")
sys.exit(1)
model = load_whisper_model(
model_size=args.model_size,
resource_path=args.resource_path
)
else:
# Use fine-tuned model
if not args.checkpoint:
logger.error("--checkpoint must be specified when not using original model")
sys.exit(1)
model = load_whisper_model(
model_size=args.model_size,
checkpoint_path=args.checkpoint,
resource_path=args.resource_path
)
# Prepare transcription keyword arguments
transcribe_kwargs = {
"language": args.language,
"task": args.task,
"temperature": args.temperature,
"beam_size": args.beam_size,
"without_timestamps": args.without_timestamps,
}
# Run transcription
if args.audio_file:
result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
if result:
print("-" * 40)
print("Transcription:")
print(result["text"])
print("-" * 40)
# Save to output directory if specified
if args.output_dir:
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
file_name = os.path.splitext(os.path.basename(args.audio_file))[0]
with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
f.write(result["text"].strip())
import json
with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Results saved to {output_dir}")
elif args.audio_dir:
batch_transcribe(
model,
args.audio_dir,
args.output_dir,
args.resource_path,
**transcribe_kwargs
)
elif args.command == "evaluate":
from evaluate import load_model, evaluate_manifest
model = load_model(
model_size=args.model_size,
checkpoint_path=args.checkpoint,
resource_path=args.resource_path
)
evaluate_manifest(
model=model,
manifest_path=args.manifest,
resource_path=args.resource_path,
language=args.language,
task=args.task,
temperature=args.temperature,
beam_size=args.beam_size,
without_timestamps=args.without_timestamps,
output_dir=args.output_dir,
max_samples=args.max_samples
)
if __name__ == "__main__":
main()
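
Putting the CLI together, typical inference and evaluation runs look like the following; the checkpoint, manifest, and resource paths are placeholders, and all flags correspond to the subparser options defined above:

```bash
# Transcribe a single file with a fine-tuned checkpoint
python whisper_cli.py infer \
    --audio_file ./data/audio/sample.wav \
    --checkpoint ./exp/whisper_fine_tune/final.pdparams \
    --resource_path ./resources \
    --language en

# Evaluate the original base model on a test manifest
python whisper_cli.py evaluate \
    --manifest ./data/manifest.test \
    --model_size base \
    --resource_path ./resources \
    --output_dir ./results
```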

@ -0,0 +1,41 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Whisper fine-tuning module.
This module provides utilities and classes for fine-tuning Whisper models
on custom datasets, following the paper "Robust Speech Recognition via
Large-Scale Weak Supervision" and Hugging Face's fine-tuning approach.
References:
- Radford, A. et al. (2022). Robust Speech Recognition via Large-Scale Weak Supervision.
- https://github.com/openai/whisper
- https://huggingface.co/blog/fine-tune-whisper
- Whisper Original Paper: https://arxiv.org/abs/2212.04356
"""
from paddlespeech.s2t.models.whisper.fine_tune.dataset import (
WhisperDataset,
WhisperInferenceDataset
)
from paddlespeech.s2t.models.whisper.fine_tune.trainer import (
WhisperTrainer
)
__all__ = [
'WhisperDataset',
'WhisperInferenceDataset',
'WhisperTrainer',
]

@ -0,0 +1,360 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
from paddle.io import Dataset, DataLoader
from paddlespeech.s2t.models.whisper.whisper import (N_MELS, N_SAMPLES, log_mel_spectrogram, pad_or_trim)
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class WhisperDataset(Dataset):
"""Dataset for Whisper fine-tuning"""
def __init__(
self,
manifest_path: str,
tokenizer,
target_language: str = "en",
task: str = "transcribe",
max_duration: float = 30.0,
min_duration: float = 0.5,
sample_rate: int = 16000,
resource_path: str = '',
pad_to_max_length: bool = False,
):
"""Initialize the dataset.
Args:
manifest_path: Path to manifest file with audio paths and transcripts
tokenizer: Whisper tokenizer
target_language: Target language code
task: Task type, either 'transcribe' or 'translate'
max_duration: Maximum audio duration
min_duration: Minimum audio duration
sample_rate: Audio sample rate
resource_path: Path to resources directory
pad_to_max_length: Whether to pad all sequences to maximum length in batch
"""
super().__init__()
self.tokenizer = tokenizer
self.target_language = target_language
self.task = task
self.sample_rate = sample_rate
self.resource_path = resource_path
self.pad_to_max_length = pad_to_max_length
# Load manifest
with open(manifest_path, 'r', encoding='utf8') as f:
manifest_lines = f.readlines()
self.data = []
for line in manifest_lines:
try:
item = json.loads(line.strip())
duration = item.get('duration', 0)
if min_duration <= duration <= max_duration:
self.data.append(item)
except Exception as e:
logger.warning(f"Error parsing line in manifest: {e}")
logger.info(f"Loaded {len(self.data)} examples from {manifest_path}")
# Generate special tokens and language tokens
self.special_tokens = {
"sot": self.tokenizer.sot,
"eot": self.tokenizer.eot,
"translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None,
"transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None,
"no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None,
}
# Get language token
self.language_token = self.tokenizer.language_tokens.get(self.target_language) if hasattr(self.tokenizer, "language_tokens") else None
if not self.language_token and hasattr(self.tokenizer, "language_token"):
self.language_token = self.tokenizer.language_token(self.target_language)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# Load audio
audio_path = item["audio"]
try:
# Process audio to mel spectrogram
mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
            mel = pad_or_trim(mel, N_SAMPLES // 160)  # pad/trim the mel frames to the fixed 30 s window (N_SAMPLES // hop_length frames)
# Get text
text = item["text"]
# Create prompt tokens
prompt_tokens = [self.special_tokens["sot"]]
# Add language token if available
if self.language_token is not None:
prompt_tokens.append(self.language_token)
# Add task token if available
task_token = self.special_tokens.get(self.task)
if task_token is not None:
prompt_tokens.append(task_token)
# Add no_timestamps token if available
if self.special_tokens["no_timestamps"] is not None:
prompt_tokens.append(self.special_tokens["no_timestamps"])
# Encode the text
target_tokens = (
prompt_tokens +
self.tokenizer.encode(text) +
[self.special_tokens["eot"]]
)
return {
"mel": mel,
"target_tokens": np.array(target_tokens, dtype=np.int64),
"text": text,
"audio_path": audio_path
}
except Exception as e:
logger.warning(f"Error processing {audio_path}: {e}")
# Return a dummy sample that will be filtered in collate_fn
return None
@staticmethod
def collate_fn(batch, pad_idx=-100):
"""Collate function for DataLoader"""
# Filter None samples
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
# Get maximum sequence length in this batch
max_target_len = max(len(sample["target_tokens"]) for sample in batch)
# Initialize tensors
mel_specs = []
token_ids = []
labels = []
# Also collect metadata for debugging/logging
texts = []
audio_paths = []
for sample in batch:
target_tokens = sample["target_tokens"]
target_len = len(target_tokens)
# Prepare inputs and labels for causal LM training
# Input tokens are shifted right
input_tokens = np.zeros(max_target_len, dtype=np.int64)
input_tokens[:target_len-1] = target_tokens[:target_len-1] # Exclude EOS
# Labels are shifted left and padded with pad_idx (ignore index)
label = np.full(max_target_len, pad_idx, dtype=np.int64)
label[:target_len-1] = target_tokens[1:target_len] # Start from first token after SOT
# Add to lists
mel_specs.append(sample["mel"])
token_ids.append(input_tokens)
labels.append(label)
# Collect metadata
texts.append(sample["text"])
audio_paths.append(sample["audio_path"])
# Convert to tensors
mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)
token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64)
labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64)
return {
"mel": mel_specs,
"tokens": token_ids,
"labels": labels,
"texts": texts,
"audio_paths": audio_paths,
}
def create_dataloader(self,
batch_size=16,
num_workers=4,
shuffle=True,
drop_last=False):
"""Create a dataloader from this dataset"""
return DataLoader(
self,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=self.collate_fn,
drop_last=drop_last
)
class WhisperInferenceDataset(Dataset):
"""Dataset for Whisper inference with batching support"""
def __init__(
self,
audio_paths: Union[str, List[str]],
tokenizer=None,
language: Optional[str] = None,
task: str = "transcribe",
sample_rate: int = 16000,
resource_path: str = '',
):
"""Initialize the inference dataset.
Args:
audio_paths: Path to audio file or directory, or list of audio file paths
tokenizer: Whisper tokenizer (optional, only needed for preparing decoder input)
language: Language code (e.g., 'en', 'fr')
task: Task type, either 'transcribe' or 'translate'
sample_rate: Audio sample rate
resource_path: Path to resources directory
"""
super().__init__()
self.tokenizer = tokenizer
self.language = language
self.task = task
self.sample_rate = sample_rate
self.resource_path = resource_path
# Process audio paths
if isinstance(audio_paths, str):
# Single file
if os.path.isfile(audio_paths):
self.audio_paths = [audio_paths]
# Directory
elif os.path.isdir(audio_paths):
self.audio_paths = [
os.path.join(audio_paths, f) for f in os.listdir(audio_paths)
if f.endswith(('.wav', '.mp3', '.flac', '.ogg'))
]
else:
raise ValueError(f"Path not found: {audio_paths}")
else:
# List of files
self.audio_paths = audio_paths
logger.info(f"Loaded {len(self.audio_paths)} audio files for inference")
# Generate special tokens and language tokens if tokenizer provided
self.special_tokens = None
self.language_token = None
if self.tokenizer:
self.special_tokens = {
"sot": self.tokenizer.sot,
"eot": self.tokenizer.eot,
"translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None,
"transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None,
"no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None,
}
# Get language token
if self.language and hasattr(self.tokenizer, "language_tokens"):
self.language_token = self.tokenizer.language_tokens.get(self.language)
elif self.language and hasattr(self.tokenizer, "language_token"):
self.language_token = self.tokenizer.language_token(self.language)
def __len__(self):
return len(self.audio_paths)
def __getitem__(self, idx):
audio_path = self.audio_paths[idx]
# Compute mel spectrogram
try:
            mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
except Exception as e:
logger.error(f"Error processing audio file {audio_path}: {e}")
# Return zero tensor with correct shape in case of error
mel = paddle.zeros((N_MELS, N_SAMPLES // 160))
return {
"mel": mel.numpy(),
"audio_path": audio_path,
}
@staticmethod
def collate_fn(batch):
"""Collate function for inference DataLoader"""
# Extract mel spectrograms and audio paths
mel_specs = []
audio_paths = []
for sample in batch:
mel_specs.append(sample["mel"])
audio_paths.append(sample["audio_path"])
# Convert to tensors
mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)
return {
"mel": mel_specs,
"audio_paths": audio_paths,
}
def create_dataloader(self,
batch_size=1,
num_workers=4):
"""Create a dataloader from this inference dataset"""
return DataLoader(
self,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=self.collate_fn,
drop_last=False
)
def prepare_decoder_input(self):
"""Prepare decoder input tokens for initial prompt
Only used when tokenizer is provided.
"""
if not self.tokenizer:
logger.warning("Cannot prepare decoder input without tokenizer")
return None
# Create initial tokens - similar to training but without labels
tokens = [self.special_tokens["sot"]]
if self.language_token:
tokens.append(self.language_token)
if self.task == "translate":
tokens.append(self.special_tokens["translate"])
elif self.task == "transcribe":
tokens.append(self.special_tokens["transcribe"])
if self.special_tokens["no_timestamps"]:
tokens.append(self.special_tokens["no_timestamps"])
return paddle.to_tensor([tokens], dtype=paddle.int64)
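
A minimal sketch of wiring `WhisperDataset` into a training `DataLoader` follows; the manifest and resource paths are placeholders, and the `get_tokenizer` call mirrors the usage in the trainer below:

```python
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.fine_tune.dataset import WhisperDataset

# Placeholder paths; adjust to the data prepared by `whisper_cli.py prepare`.
tokenizer = get_tokenizer(
    multilingual=True,
    resource_path="./resources",
    language="en",
    task="transcribe",
)
train_dataset = WhisperDataset(
    manifest_path="./data/manifest.train",
    tokenizer=tokenizer,
    target_language="en",
    task="transcribe",
    resource_path="./resources",
)
train_loader = train_dataset.create_dataloader(batch_size=8, num_workers=2)

for batch in train_loader:
    if batch is None:  # collate_fn returns None when every sample in the batch failed to load
        continue
    print(batch["mel"].shape, batch["tokens"].shape, batch["labels"].shape)
    break
```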

@ -0,0 +1,380 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed as dist
from paddle.io import DataLoader
from paddle.optimizer import AdamW
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper
from paddlespeech.s2t.training.reporter import ObsScope, Reporter
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_rank
logger = Log(__name__).getlog()
class WhisperTrainer:
"""Trainer for fine-tuning Whisper models."""
def __init__(
self,
config: Dict,
model: Optional[Whisper] = None,
optimizer: Optional[paddle.optimizer.Optimizer] = None,
checkpoint_dir: Optional[Union[str, Path]] = None,
resource_path: Optional[str] = None,
log_interval: int = 10,
save_interval: int = 1,
rank: int = 0,
world_size: int = 1,
):
"""Initialize the trainer.
Args:
config: Training configuration dictionary
model: Whisper model instance, or None to build one from config
optimizer: Optimizer instance, or None to build one from config
checkpoint_dir: Directory to save checkpoints
resource_path: Path to resources directory containing whisper assets
log_interval: Steps between logging
save_interval: Epochs between checkpoint saving
rank: Local rank for distributed training
world_size: Total number of processes for distributed training
"""
self.config = config
self.resource_path = resource_path
self.log_interval = log_interval
self.save_interval = save_interval
self.rank = rank
self.world_size = world_size
# Initialize model if not provided
self.model = model if model is not None else self._build_model()
# Initialize tokenizer
self.tokenizer = self._build_tokenizer()
# Initialize optimizer if not provided
self.optimizer = optimizer if optimizer is not None else self._build_optimizer()
# Initialize checkpoint
self.checkpoint = self._init_checkpoint(checkpoint_dir)
# Initialize reporter and timer
self.reporter = Reporter()
self.timer = Timer()
def _build_model(self) -> Whisper:
"""Build Whisper model from config."""
model_config = self.config['model']
model_size = model_config['size']
# Load model dimensions
model_dims = MODEL_DIMENSIONS[model_size]
# Create the model
model = Whisper(model_dims)
# Load checkpoint if provided
checkpoint_path = model_config.get('checkpoint')
if checkpoint_path:
logger.info(f"Loading model weights from {checkpoint_path}")
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
# Freeze encoder if needed
if model_config.get('freeze_encoder', False):
logger.info("Freezing encoder parameters")
for param in model.encoder.parameters():
param.stop_gradient = True
# Handle distributed training
if self.world_size > 1:
logger.info(f"Initializing distributed model with {self.world_size} processes")
model = paddle.DataParallel(model)
return model
def _build_tokenizer(self):
"""Build Whisper tokenizer."""
target_language = self.config['data'].get('target_language', 'en')
task = self.config['data'].get('task', 'transcribe')
return get_tokenizer(
multilingual=self.model.is_multilingual,
resource_path=self.resource_path,
language=target_language,
task=task
)
def _build_optimizer(self) -> paddle.optimizer.Optimizer:
"""Build optimizer and learning rate scheduler."""
training_config = self.config['training']
# Set learning rate
learning_rate = training_config['learning_rate']
# Apply weight decay
decay_params = [
p.name for n, p in self.model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
# Build scheduler
scheduler_name = training_config.get('scheduler', 'linear')
max_steps = training_config.get('max_steps', 100000)
# Calculate warmup steps
warmup_ratio = training_config.get('warmup_ratio', 0.1)
warmup_steps = int(max_steps * warmup_ratio)
if scheduler_name == 'cosine':
lr_scheduler = LinearWarmup(
CosineAnnealingDecay(learning_rate, max_steps - warmup_steps),
warmup_steps,
0.0,
learning_rate
)
else: # default to linear
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
paddle.optimizer.lr.PolynomialDecay(
learning_rate=learning_rate,
decay_steps=max_steps - warmup_steps,
end_lr=0.0,
power=1.0),
warmup_steps,
0.0,
learning_rate
)
# Create optimizer
weight_decay = training_config.get('weight_decay', 0.01)
max_grad_norm = training_config.get('max_grad_norm', 1.0)
optimizer = AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=self.model.parameters(),
weight_decay=weight_decay,
grad_clip=nn.ClipGradByNorm(max_grad_norm),
apply_decay_param_fun=lambda x: x in decay_params
)
return optimizer
def _init_checkpoint(self, checkpoint_dir=None) -> Checkpoint:
"""Initialize checkpoint for saving and loading."""
if checkpoint_dir is None:
checkpoint_dir = self.config['output'].get('checkpoint_dir', './exp/whisper_fine_tune')
checkpoint_dir = Path(checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
return Checkpoint(
checkpoint_dir=checkpoint_dir,
model=self.model,
optimizer=self.optimizer,
infos=dict(),
visualizer=None,
**{"epoch": 0}
)
def train(self, train_loader: DataLoader, dev_loader: Optional[DataLoader] = None, num_epochs: int = 10):
"""Run training loop.
Args:
train_loader: DataLoader for training data
dev_loader: DataLoader for validation data
num_epochs: Number of epochs to train
"""
self.reporter.register(self.model, "model")
# Get training configuration
training_config = self.config['training']
max_epoch = num_epochs or training_config.get('max_epoch', 10)
accum_grad = training_config.get('accum_grad', 1)
# Start training
logger.info(f"Starting training for {max_epoch} epochs")
self.timer.start()
# Resume from checkpoint if epoch > 0
start_epoch = self.checkpoint.infos.get("epoch", 0)
for epoch in range(start_epoch, max_epoch):
self._train_epoch(train_loader, epoch, accum_grad)
# Validation
if dev_loader is not None:
dev_loss = self._eval_epoch(dev_loader, epoch)
self.checkpoint.infos["dev_loss"] = dev_loss
# Update epoch in checkpoint
self.checkpoint.infos["epoch"] = epoch + 1
# Save checkpoint
if (epoch + 1) % self.save_interval == 0 and self.rank == 0:
logger.info(f"Saving checkpoint at epoch {epoch + 1}")
self.checkpoint.save_parameters(tag=f"epoch_{epoch + 1}")
# Save final model
if self.rank == 0:
logger.info("Saving final model")
self.checkpoint.save_parameters(tag="final")
def _train_epoch(self, train_loader: DataLoader, epoch: int, accum_grad: int = 1):
"""Train for one epoch."""
self.model.train()
train_loss = 0.0
num_batches = 0
steps_per_epoch = len(train_loader)
for batch_idx, batch in enumerate(train_loader):
if batch is None:
continue
# Get batch data
mel = batch["mel"]
tokens = batch["tokens"]
labels = batch["labels"]
# Forward pass
audio_features = self.model.embed_audio(mel)
logits = self.model.logits(tokens, audio_features)
# Compute loss
loss = F.cross_entropy(
logits.reshape([-1, logits.shape[-1]]),
labels.reshape([-1]),
ignore_index=-100
)
# Scale loss for gradient accumulation
if accum_grad > 1:
loss = loss / accum_grad
# Backward pass
loss.backward()
# Update parameters every accum_grad steps
if (batch_idx + 1) % accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
# Logging
train_loss += loss.item() * (accum_grad if accum_grad > 1 else 1)
num_batches += 1
# Log training progress
if batch_idx % self.log_interval == 0 and self.rank == 0:
elapsed_time = self.timer.elapsed_interval()
step = epoch * steps_per_epoch + batch_idx
logger.info(f"Epoch {epoch+1} | Batch {batch_idx}/{steps_per_epoch} | "
f"Loss: {loss.item()*(accum_grad if accum_grad > 1 else 1):.4f} | "
f"Step {step} | {elapsed_time:.2f}s elapsed")
# Compute average loss
avg_loss = train_loss / num_batches if num_batches > 0 else float('inf')
# Log epoch stats
if self.rank == 0:
logger.info(f"Epoch {epoch+1} | Average Training Loss: {avg_loss:.4f}")
self.checkpoint.infos["train_loss"] = avg_loss
def _eval_epoch(self, dev_loader: DataLoader, epoch: int):
"""Evaluate for one epoch."""
self.model.eval()
dev_losses = []
with paddle.no_grad():
for batch in dev_loader:
if batch is None:
continue
# Get batch data
mel = batch["mel"]
tokens = batch["tokens"]
labels = batch["labels"]
# Forward pass
audio_features = self.model.embed_audio(mel)
logits = self.model.logits(tokens, audio_features)
# Compute loss
loss = F.cross_entropy(
logits.reshape([-1, logits.shape[-1]]),
labels.reshape([-1]),
ignore_index=-100
)
dev_losses.append(loss.item())
avg_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf')
if self.rank == 0:
logger.info(f"Epoch {epoch+1} | Validation Loss: {avg_loss:.4f}")
return avg_loss
def save(self, tag: str = "final"):
"""Save model checkpoint."""
if self.rank == 0:
logger.info(f"Saving model checkpoint with tag '{tag}'")
self.checkpoint.save_parameters(tag=tag)
def load(self, checkpoint_path: Union[str, Path]):
"""Load model from checkpoint."""
logger.info(f"Loading checkpoint from {checkpoint_path}")
self.checkpoint.load_parameters(checkpoint_path)
@classmethod
def from_pretrained(cls,
config: Dict,
model_size: str = "base",
checkpoint_path: Optional[str] = None,
resource_path: Optional[str] = None):
"""Create a trainer from pretrained model."""
# Create model
model_dims = MODEL_DIMENSIONS[model_size]
model = Whisper(model_dims)
# Load checkpoint if provided
if checkpoint_path:
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
# Create trainer
trainer = cls(
config=config,
model=model,
resource_path=resource_path
)
return trainer
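
As a rough sketch, the trainer and dataset classes combine as follows; the config dictionary, model size, and paths below are placeholders, and a real run is driven by `whisper_cli.py train` with a YAML config:

```python
from paddlespeech.s2t.models.whisper.fine_tune import WhisperDataset, WhisperTrainer

# Minimal placeholder config covering the keys read by the trainer.
config = {
    "model": {"size": "base"},
    "data": {"target_language": "en", "task": "transcribe"},
    "training": {"learning_rate": 1e-5, "max_steps": 10000, "accum_grad": 2},
    "output": {"checkpoint_dir": "./exp/whisper_fine_tune"},
}

trainer = WhisperTrainer.from_pretrained(
    config=config,
    model_size="base",
    checkpoint_path="./resources/whisper/whisper-base.pdparams",  # placeholder path
    resource_path="./resources",
)

def make_loader(manifest_path, shuffle=True):
    dataset = WhisperDataset(
        manifest_path=manifest_path,
        tokenizer=trainer.tokenizer,
        target_language="en",
        resource_path="./resources",
    )
    return dataset.create_dataloader(batch_size=8, shuffle=shuffle)

trainer.train(
    train_loader=make_loader("./data/manifest.train"),
    dev_loader=make_loader("./data/manifest.dev", shuffle=False),
    num_epochs=5,
)
```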

@ -0,0 +1,271 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import paddle
import yaml
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.whisper import (
MODEL_DIMENSIONS,
LANGUAGES,
Whisper,
DecodingOptions,
load_model,
transcribe
)
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
def load_whisper_model(
model_size: str = "base",
checkpoint_path: Optional[str] = None,
resource_path: Optional[str] = None,
) -> Whisper:
"""Load Whisper model from checkpoint or pretrained weights.
Args:
model_size: Model size for Whisper
checkpoint_path: Path to model checkpoint from fine-tuning
resource_path: Path to resources directory containing original models
Returns:
Whisper model
"""
model_dims = MODEL_DIMENSIONS[model_size]
model = Whisper(model_dims)
if checkpoint_path:
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
state_dict = paddle.load(checkpoint_path)
model.set_state_dict(state_dict)
elif resource_path:
model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams")
if os.path.exists(model_path):
logger.info(f"Loading pretrained model from: {model_path}")
state_dict = paddle.load(model_path)
model.set_state_dict(state_dict)
else:
logger.error(f"Pretrained model not found at {model_path}")
raise FileNotFoundError(f"Model file not found: {model_path}")
else:
logger.error("Either checkpoint_path or resource_path must be provided")
raise ValueError("Either checkpoint_path or resource_path must be provided")
return model
def detect_language(
model: Whisper,
mel: paddle.Tensor,
tokenizer=None,
resource_path: Optional[str] = None
) -> str:
"""Detect language from audio.
Args:
model: Whisper model
mel: Mel spectrogram
tokenizer: Optional tokenizer
resource_path: Path to resources directory (required if tokenizer not provided)
Returns:
Detected language code
"""
if not tokenizer and not resource_path:
raise ValueError("Either tokenizer or resource_path must be provided")
if not tokenizer:
tokenizer = get_tokenizer(
multilingual=model.is_multilingual,
resource_path=resource_path
)
# Get audio features from encoder
audio_features = model.embed_audio(mel)
# Get initial tokens
initial_tokens = paddle.to_tensor([[tokenizer.sot]], dtype=paddle.int64)
# Extract language token logits
token_logits = model.logits(initial_tokens, audio_features)
language_token_logits = token_logits[:, 0, tokenizer.language_token_ranges]
# Get language token with highest probability
language_token_probs = paddle.softmax(language_token_logits, axis=-1)
language_token_id = paddle.argmax(language_token_probs, axis=-1)
detected_language_token_id = language_token_id.item() + tokenizer.language_token_ranges[0]
# Map token to language code
language_token = tokenizer.all_tokens[detected_language_token_id]
    language_code = language_token.strip("<|>")  # e.g. "<|en|>" -> "en"
return language_code
def get_available_languages() -> List[str]:
"""Get list of available languages in Whisper.
Returns:
List of language codes
"""
return sorted(LANGUAGES.keys())
def format_timestamp(
seconds: float,
always_include_hours: bool = False,
decimal_marker: str = "."
) -> str:
"""Format a timestamp into a string.
Args:
seconds: Number of seconds
always_include_hours: Whether to always include hours
decimal_marker: Decimal marker character
Returns:
Formatted timestamp string
"""
seconds = max(0, seconds)
hours = seconds // 3600
seconds = seconds - hours * 3600
minutes = seconds // 60
seconds = seconds - minutes * 60
hours_marker = f"{int(hours):02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{int(minutes):02d}:{int(seconds):02d}{decimal_marker}{int(seconds * 100 % 100):02d}"
def save_srt(
transcript: Dict,
file_path: str
):
"""Save transcript to SRT file.
Args:
transcript: Transcript dictionary from Whisper
file_path: Path to output SRT file
"""
if not transcript.get("segments"):
return
with open(file_path, "w", encoding="utf-8") as f:
for i, segment in enumerate(transcript["segments"], start=1):
start = format_timestamp(segment["start"], always_include_hours=True, decimal_marker=",")
end = format_timestamp(segment["end"], always_include_hours=True, decimal_marker=",")
text = segment["text"].strip().replace("-->", "->")
f.write(f"{i}\n")
f.write(f"{start} --> {end}\n")
f.write(f"{text}\n\n")
def batch_transcribe_with_progress(
model: Whisper,
dataset,
dataloader,
resource_path: str,
language: Optional[str] = None,
task: str = "transcribe",
beam_size: int = 5,
temperature: float = 0.0,
without_timestamps: bool = True,
verbose: bool = False,
decoder_options: Optional[dict] = None,
):
"""Transcribe audio files in batches with progress reporting.
Args:
model: Whisper model
dataset: WhisperInferenceDataset
dataloader: DataLoader from dataset
resource_path: Path to resources directory
language: Language code or None for auto-detection
task: Task (transcribe or translate)
beam_size: Beam size for beam search
temperature: Temperature for sampling
without_timestamps: Whether to include timestamps
verbose: Whether to show verbose output
decoder_options: Additional decoder options
Returns:
List of transcription results
"""
model.eval()
tokenizer = get_tokenizer(
multilingual=model.is_multilingual,
resource_path=resource_path,
language=language,
task=task
)
results = []
total_batches = len(dataloader)
with paddle.no_grad():
for batch_idx, batch in enumerate(dataloader):
if verbose:
logger.info(f"Processing batch {batch_idx+1}/{total_batches}")
mel = batch["mel"] # [batch, n_mels, n_frames]
audio_paths = batch["audio_paths"]
batch_size = mel.shape[0]
# Process each item in batch
for i in range(batch_size):
audio_path = audio_paths[i]
mel_i = mel[i:i+1] # Keep batch dimension
# Auto-detect language if needed
current_language = language
if not current_language:
current_language = detect_language(model, mel_i, tokenizer, resource_path)
if verbose:
logger.info(f"Detected language: {current_language}")
# Set up decoding options
options = DecodingOptions(
task=task,
language=current_language,
beam_size=beam_size,
temperature=temperature,
without_timestamps=without_timestamps,
**decoder_options if decoder_options else {}
)
# Transcribe
result = transcribe(
model=model,
tokenizer=tokenizer,
mel=mel_i,
resource_path=resource_path,
verbose=verbose,
**options.__dict__
)
# Add file path to result
result["audio_path"] = audio_path
results.append(result)
if verbose:
logger.info(f"Transcribed {os.path.basename(audio_path)}: {result['text'][:80]}...")
return results
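
For completeness, a minimal sketch of batched transcription using the helpers above; `load_whisper_model` and `batch_transcribe_with_progress` are the functions defined in this file, the import path for `WhisperInferenceDataset` follows the package layout of this example, and all paths are placeholders:

```python
from paddlespeech.s2t.models.whisper.fine_tune.dataset import WhisperInferenceDataset

model = load_whisper_model(
    model_size="base",
    checkpoint_path="./exp/whisper_fine_tune/final.pdparams",  # placeholder checkpoint
    resource_path="./resources",
)
dataset = WhisperInferenceDataset(
    audio_paths="./data/audio",  # a file, a directory, or a list of files
    language="en",
    resource_path="./resources",
)
dataloader = dataset.create_dataloader(batch_size=4, num_workers=2)

results = batch_transcribe_with_progress(
    model=model,
    dataset=dataset,
    dataloader=dataloader,
    resource_path="./resources",
    language="en",
    beam_size=5,
    verbose=True,
)
for result in results:
    print(result["audio_path"], "->", result["text"])
```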