@@ -0,0 +1,238 @@
# Whisper Fine-tuning on Common Voice

This example demonstrates how to fine-tune Whisper models on the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) dataset using PaddleSpeech.

## Overview

Whisper is a state-of-the-art speech recognition model from OpenAI. This implementation allows you to fine-tune Whisper models on new datasets to improve performance for specific languages, domains, or accents.

## Features

- Complete fine-tuning pipeline for Whisper models on custom datasets
- Flexible configuration via YAML files
- Support for all Whisper model sizes (tiny, base, small, medium, large, large-v2, large-v3)
- Data preparation tools for Common Voice and custom datasets
- Distributed training support with mixed precision
- Gradient accumulation for large-batch training
- Learning rate scheduling and optimization techniques
- Evaluation tools with WER/CER metrics
- Command-line inference with both fine-tuned and original Whisper models
- Model export utilities for deployment
- Visualization tools for performance analysis

## Installation

Ensure you have PaddleSpeech installed with all dependencies:

```bash
git clone https://github.com/PaddlePaddle/PaddleSpeech.git
cd PaddleSpeech
pip install -e .
pip install datasets soundfile librosa matplotlib pandas jiwer
```
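
Optionally, verify that PaddlePaddle can see your device before continuing, using Paddle's built-in self check:

```bash
python -c "import paddle; paddle.utils.run_check()"
```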

## Data

We use the Common Voice dataset (version 11.0) available on Hugging Face:
https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0

Other datasets compatible with the pipeline include LibriSpeech, AISHELL, and any dataset that can be converted to the manifest format (see below).

## Unified Command-Line Interface

This example includes a unified CLI for all operations:

```bash
python whisper_cli.py COMMAND [OPTIONS]
```

Available commands:
- `prepare`: Prepare dataset for fine-tuning
- `train`: Fine-tune Whisper model
- `evaluate`: Evaluate model performance
- `infer`: Run inference with fine-tuned or original model
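
A typical end-to-end run chains the four commands; a sketch using the options documented in the sections below:

```bash
python whisper_cli.py prepare --language en --output_dir ./data
python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources
python whisper_cli.py evaluate --manifest ./data/test_manifest.json \
    --checkpoint ./exp/whisper_fine_tune/final --output_dir ./eval_results
python whisper_cli.py infer --audio_file sample.wav --checkpoint ./exp/whisper_fine_tune/final
```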

## Data Preparation

To download and prepare the Common Voice dataset:

```bash
python whisper_cli.py prepare --language en --output_dir ./data
```

Options:
- `--language`: Target language code (default: en)
- `--output_dir`: Directory to save preprocessed data (default: ./data)
- `--cache_dir`: Cache directory for HuggingFace datasets
- `--val_size`: Validation set size ratio (default: 0.03)
- `--test_size`: Test set size ratio (default: 0.03)
- `--min_duration`: Minimum audio duration in seconds (default: 0.5)
- `--max_duration`: Maximum audio duration in seconds (default: 30.0)

Manifest format (JSON Lines, one object per line):
```json
{"audio": "path/to/audio.wav", "text": "transcription", "duration": 3.45}
```
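
Manifests are plain JSON Lines files, so they are easy to inspect or post-process. A minimal sketch for summarizing a prepared manifest (path taken from the prepare step above):

```python
import json

# Each manifest line is one JSON object with audio path, transcript, and duration.
with open("data/train_manifest.json", encoding="utf-8") as f:
    items = [json.loads(line) for line in f if line.strip()]

total_hours = sum(item["duration"] for item in items) / 3600
print(f"{len(items)} utterances, {total_hours:.1f} hours")
```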

## Configuration

Fine-tuning parameters are specified in YAML config files. See `conf/whisper_base.yaml` for a detailed example.

Key configuration sections:
- **model**: Model size, checkpoint path, freeze options
- **data**: Dataset paths, languages, tasks
- **training**: Batch size, learning rate, optimizer settings
- **distributed**: Distributed training options
- **output**: Save paths, logging options
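
A trimmed excerpt of `conf/whisper_base.yaml` showing how these sections fit together:

```yaml
model:
  size: "base"              # tiny, base, small, medium, large, large-v2, large-v3
  freeze_encoder: false

data:
  train_manifest: "data/train_manifest.json"
  target_language: "en"

training:
  max_epoch: 20
  batch_size: 16
  learning_rate: 1e-5

output:
  checkpoint_dir: "exp/whisper_fine_tune"
```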

## Training

To fine-tune the Whisper model:

```bash
python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources
```

For distributed training:

```bash
python -m paddle.distributed.launch --gpus "0,1,2,3" whisper_cli.py train --config conf/whisper_base.yaml --distributed True
```
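
Note that `paddle.distributed.launch` starts one training process per listed GPU, and `batch_size` in the config applies per process, so the command above trains with a global batch size of `4 × batch_size` (times `accum_grad`, if gradient accumulation is enabled).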

Options:
- `--config`: Path to configuration YAML file
- `--resource_path`: Path to resources directory containing model assets
- `--device`: Device to use (cpu, gpu, xpu)
- `--seed`: Random seed
- `--checkpoint_path`: Path to resume training from checkpoint
- `--distributed`: Enable distributed training

## Evaluation

Evaluate model performance on a test set:

```bash
python whisper_cli.py evaluate --manifest ./data/test_manifest.json --checkpoint ./exp/whisper_fine_tune/epoch_10 --output_dir ./eval_results
```

Options:
- `--manifest`: Path to test manifest file
- `--checkpoint`: Path to model checkpoint
- `--model_size`: Model size if using original Whisper
- `--language`: Language code
- `--output_dir`: Directory to save evaluation results
- `--max_samples`: Maximum number of samples to evaluate
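
For reference, WER counts word-level substitutions (S), deletions (D), and insertions (I) against the number of reference words (N), `WER = (S + D + I) / N`; CER applies the same formula at the character level, which is the more meaningful metric for languages written without word separators.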

## Inference

For transcribing audio with a fine-tuned model:

```bash
python whisper_cli.py infer --audio_file path/to/audio.wav --checkpoint ./exp/whisper_fine_tune/final
```

For batch processing a directory:

```bash
python whisper_cli.py infer --audio_dir path/to/audio/folder --output_dir ./transcriptions --checkpoint ./exp/whisper_fine_tune/final
```

For inference with the original Whisper models:

```bash
python whisper_cli.py infer --audio_file path/to/audio.wav --use_original --model_size large-v3 --resource_path ./resources
```

Options:
- `--audio_file`: Path to single audio file
- `--audio_dir`: Path to directory with audio files
- `--checkpoint`: Path to fine-tuned checkpoint
- `--use_original`: Use original Whisper model
- `--model_size`: Model size (tiny, base, small, medium, large, large-v2, large-v3)
- `--language`: Language code (omit to auto-detect)
- `--task`: Task type (transcribe or translate)
- `--beam_size`: Beam size for beam search
- `--temperature`: Temperature for sampling
- `--without_timestamps`: Don't include timestamps
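
When `--output_dir` is given, the script writes a plain-text transcript per input file (`<name>.txt`), plus the full decoding result as JSON: `<name>.json` for a single file, or a combined `all_transcripts.json` for batch runs.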

## Visualization

Visualize evaluation results:

```bash
python visualize.py --results_file ./eval_results/evaluation_results.json --output_dir ./visualizations
```

Options:
- `--results_file`: Path to evaluation results JSON file
- `--output_dir`: Directory to save visualizations
- `--audio_dir`: Directory with audio files (optional)
- `--num_samples`: Number of individual samples to visualize
- `--show`: Show plots interactively

## Model Export

Export a fine-tuned model to inference format:

```bash
python export_model.py --checkpoint ./exp/whisper_fine_tune/final --output_path ./exported_model --model_size base
```

Options:
- `--checkpoint`: Path to model checkpoint
- `--output_path`: Path to save exported model
- `--model_size`: Model size
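
Export writes static-graph programs for the encoder and decoder via `paddle.jit.save` (under `<output_path>_encoder` and `<output_path>_decoder`), together with `<output_path>_info.json` recording the model dimensions needed to drive them.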

## Advanced Usage

### Freezing Encoder

To freeze the encoder and only fine-tune the decoder, set the following in your config file:

```yaml
model:
  freeze_encoder: true
```

### Gradient Accumulation

For effective training with limited GPU memory, use gradient accumulation:

```yaml
training:
  accum_grad: 8  # Accumulate gradients over 8 batches
```
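
For example, with the default `batch_size: 16`, setting `accum_grad: 8` gives an effective batch size of 16 × 8 = 128 while only ever holding 16 samples' activations in memory.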

### Mixed Precision

Enable mixed precision training for faster computation:

```yaml
training:
  amp: true  # Enable automatic mixed precision
```

### Custom Datasets

To use custom datasets, prepare manifest files in the following format (the `duration` field, in seconds, is required, since the training loader filters entries by the configured duration range):

```json
{"audio": "/absolute/path/to/audio.wav", "text": "transcription text", "duration": 2.5}
```
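
A minimal sketch for building such a manifest from a directory of WAV files with matching `.txt` transcripts (the `my_corpus` layout is hypothetical; adapt the pairing logic to your data):

```python
import json
from pathlib import Path

import soundfile

audio_dir = Path("my_corpus/wavs")  # hypothetical layout: foo.wav + foo.txt
with open("train_manifest.json", "w", encoding="utf-8") as manifest:
    for wav_path in sorted(audio_dir.glob("*.wav")):
        txt_path = wav_path.with_suffix(".txt")
        if not txt_path.exists():
            continue  # skip audio without a transcript
        info = soundfile.info(str(wav_path))  # read duration without decoding
        entry = {
            "audio": str(wav_path.resolve()),
            "text": txt_path.read_text(encoding="utf-8").strip(),
            "duration": info.frames / info.samplerate,
        }
        manifest.write(json.dumps(entry, ensure_ascii=False) + "\n")
```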

Then specify the manifest paths in your config file:

```yaml
data:
  train_manifest: path/to/train_manifest.json
  dev_manifest: path/to/dev_manifest.json
  test_manifest: path/to/test_manifest.json
```

## References

- [OpenAI Whisper](https://github.com/openai/whisper)
- [Fine-tuning Whisper](https://huggingface.co/blog/fine-tune-whisper)
- [Paper: Robust Speech Recognition via Large-Scale Weak Supervision](https://arxiv.org/pdf/2212.04356)
- [Common Voice Dataset](https://commonvoice.mozilla.org/)
- [PaddleSpeech Documentation](https://github.com/PaddlePaddle/PaddleSpeech)
@@ -0,0 +1,70 @@
# Configuration for fine-tuning Whisper base model

# Data settings
data:
  train_manifest: "data/train_manifest.json"
  dev_manifest: "data/dev_manifest.json"
  test_manifest: "data/test_manifest.json"
  target_language: "en"   # Language code for fine-tuning
  max_duration: 30.0      # Maximum audio duration in seconds
  min_duration: 0.5       # Minimum audio duration in seconds

# Model settings
model:
  name: "whisper"
  size: "base"            # Options: tiny, base, small, medium, large, large-v2, large-v3
  checkpoint: null        # Path to pre-trained checkpoint, null for default
  freeze_encoder: false   # Whether to freeze the encoder during fine-tuning
  use_fp16: true          # Whether to use half precision

# Training settings
training:
  max_epoch: 20
  save_epoch: 1
  log_interval: 100
  batch_size: 16
  num_workers: 4
  accum_grad: 1           # Gradient accumulation steps

  # Optimizer settings
  optimizer: "adamw"
  learning_rate: 1e-5
  weight_decay: 0.01
  scheduler: "cosine"
  warmup_ratio: 0.03
  max_grad_norm: 1.0

  # Regularization
  dropout: 0.1
  label_smoothing: 0.1

  # Mixed precision training
  amp_level: "O1"
  amp_dtype: "float16"

# Distributed training
distributed:
  use_fleet: true
  strategy: "standard"
  find_unused_parameters: false

# Output settings
output:
  checkpoint_dir: "exp/whisper_fine_tune"
  save_checkpoint: true
  save_interval: 1
  keep_checkpoint_max: 5

# Evaluation settings
eval:
  eval_batch_size: 16
  metrics: ["wer", "cer"]

# Inference settings
inference:
  beam_size: 5
  min_tokens: 0
  max_tokens: 448
  temperature: 0.0
  language: null          # Set to target language code or null to auto-detect
  without_timestamps: true
@@ -0,0 +1,264 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from pathlib import Path

import paddle

from paddlespeech.s2t.models.whisper.whisper import (MODEL_DIMENSIONS, DecodingOptions,
                                                     Whisper, log_mel_spectrogram, transcribe)
from paddlespeech.s2t.utils.error_rate import char_errors, word_errors
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import str2bool

logger = Log(__name__).getlog()


def compute_metrics(references, hypotheses, language="en"):
    """Compute corpus-level WER and CER over paired reference/hypothesis lists."""
    total_words = 0
    total_chars = 0
    total_word_errors = 0
    total_char_errors = 0

    for ref, hyp in zip(references, hypotheses):
        ref = ref.strip()
        hyp = hyp.strip()

        # The error_rate helpers return (edit_distance, reference_length).
        word_error_count, word_count = word_errors(ref, hyp)
        # Ignore spaces at the character level for languages written
        # without them (assumed to be "zh" here).
        char_error_count, char_count = char_errors(ref, hyp, remove_space=(language == "zh"))

        total_words += word_count
        total_chars += char_count
        total_word_errors += word_error_count
        total_char_errors += char_error_count

    wer = float(total_word_errors) / max(1, total_words)
    cer = float(total_char_errors) / max(1, total_chars)

    return {
        "wer": wer,
        "cer": cer,
        "word_errors": total_word_errors,
        "word_count": total_words,
        "char_errors": total_char_errors,
        "char_count": total_chars,
    }


def load_model(model_size="base", checkpoint_path=None, resource_path=None):
    """Load a Whisper model from a fine-tuned checkpoint or pretrained weights."""
    model_dims = MODEL_DIMENSIONS[model_size]
    model = Whisper(model_dims)

    if checkpoint_path:
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")
        state_dict = paddle.load(checkpoint_path)
        model.set_state_dict(state_dict)
    elif resource_path:
        model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams")
        if os.path.exists(model_path):
            logger.info(f"Loading pretrained model from: {model_path}")
            state_dict = paddle.load(model_path)
            model.set_state_dict(state_dict)
        else:
            logger.error(f"Pretrained model not found at {model_path}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
    else:
        logger.error("Either checkpoint_path or resource_path must be provided")
        raise ValueError("Either checkpoint_path or resource_path must be provided")

    return model


def evaluate_manifest(
        model,
        manifest_path,
        resource_path,
        language=None,
        task="transcribe",
        temperature=0.0,
        beam_size=5,
        patience=1.0,
        batch_size=16,  # NOTE: decoding currently runs one utterance at a time
        without_timestamps=True,
        fp16=False,
        verbose=True,
        output_dir=None,
        max_samples=None):
    """Evaluate a Whisper model on a JSON-lines manifest file."""
    # Load manifest
    with open(manifest_path, 'r', encoding='utf8') as f:
        manifest_lines = f.readlines()

    # Limit samples if requested
    if max_samples:
        manifest_lines = manifest_lines[:max_samples]

    references = []
    hypotheses = []
    audio_paths = []
    durations = []

    # Process each item
    start_time = time.time()
    model.eval()

    for i, line in enumerate(manifest_lines):
        if i % 10 == 0:
            logger.info(f"Processing item {i+1}/{len(manifest_lines)}")

        item = json.loads(line.strip())
        audio_path = item["audio"]
        reference_text = item["text"]

        # Get duration if available
        duration = item.get("duration", 0)
        durations.append(duration)

        audio_paths.append(audio_path)
        references.append(reference_text)

        # Process audio
        try:
            mel = log_mel_spectrogram(audio_path, resource_path=resource_path)

            # Setup decoding options
            decode_options = DecodingOptions(
                language=language,
                task=task,
                temperature=temperature,
                beam_size=beam_size,
                patience=patience,
                without_timestamps=without_timestamps,
                fp16=fp16,
            )

            # Run transcription (keep per-utterance decoding quiet;
            # progress is logged above)
            with paddle.no_grad():
                result = transcribe(
                    model=model,
                    mel=mel,
                    resource_path=resource_path,
                    verbose=False,
                    **decode_options.__dict__,
                )

            hypotheses.append(result["text"])

        except Exception as e:
            logger.error(f"Error processing {audio_path}: {str(e)}")
            hypotheses.append("")

    # Compute metrics
    elapsed = time.time() - start_time
    logger.info(f"Processed {len(manifest_lines)} examples in {elapsed:.2f} seconds")

    metrics = compute_metrics(references, hypotheses, language=language if language else "en")
    logger.info(f"WER: {metrics['wer']*100:.2f}%, CER: {metrics['cer']*100:.2f}%")

    # Save results if output directory provided
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save detailed results
        results = {
            "metrics": metrics,
            "details": [
                {
                    "audio": audio_path,
                    "reference": reference,
                    "hypothesis": hypothesis,
                    "duration": duration
                }
                for audio_path, reference, hypothesis, duration in zip(
                    audio_paths, references, hypotheses, durations
                )
            ]
        }

        with open(output_dir / "evaluation_results.json", "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        logger.info(f"Detailed evaluation results saved to {output_dir}")

    return metrics


def main():
    parser = argparse.ArgumentParser(description="Evaluate Whisper ASR Model")
    parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file")
    parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results")
    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    parser.add_argument("--resource_path", type=str, default="./resources",
                        help="Path to resources directory")

    # Model options
    parser.add_argument("--model_size", type=str, default="base",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Model size for original Whisper")

    # Decoding options
    parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="Task: transcribe or translate to English")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
    parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
    parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor")
    parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
    parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16")
    parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display verbose logs")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")

    args = parser.parse_args()

    # Load model
    model = load_model(
        model_size=args.model_size,
        checkpoint_path=args.checkpoint,
        resource_path=args.resource_path
    )

    # Evaluate
    evaluate_manifest(
        model=model,
        manifest_path=args.manifest,
        resource_path=args.resource_path,
        language=args.language,
        task=args.task,
        temperature=args.temperature,
        beam_size=args.beam_size,
        patience=args.patience,
        batch_size=args.batch_size,
        without_timestamps=args.without_timestamps,
        fp16=args.fp16,
        verbose=args.verbose,
        output_dir=args.output_dir,
        max_samples=args.max_samples
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,135 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import paddle

from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


def export_encoder(model, save_path, input_shape=(1, 80, 3000)):
    """Export the encoder part of Whisper to an inference model."""
    model.eval()

    # Create the save directory if needed (dirname is empty for bare filenames)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Define input spec: (batch, n_mels, n_frames)
    mel_spec = paddle.static.InputSpec(shape=input_shape, dtype='float32', name='mel')

    # Export encoder model
    encoder_path = f"{save_path}_encoder"
    paddle.jit.save(
        layer=model.encoder,
        path=encoder_path,
        input_spec=[mel_spec]
    )
    logger.info(f"Encoder model exported to {encoder_path}")
    return encoder_path


def export_decoder(model, save_path, input_shape_tokens=(1, 448), input_shape_features=(1, 1500, 512)):
    """Export the decoder part of Whisper to an inference model.

    The default shapes match the "base" model; export_whisper passes shapes
    derived from the actual model dimensions.
    """
    model.eval()

    # Create the save directory if needed
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Define input specs
    token_spec = paddle.static.InputSpec(shape=input_shape_tokens, dtype='int64', name='tokens')
    audio_features_spec = paddle.static.InputSpec(shape=input_shape_features, dtype='float32', name='audio_features')

    # Create a wrapper to match the exact API of the decoder
    class DecoderWrapper(paddle.nn.Layer):
        def __init__(self, decoder):
            super().__init__()
            self.decoder = decoder

        def forward(self, tokens, audio_features):
            return self.decoder(tokens, audio_features)

    wrapper = DecoderWrapper(model.decoder)

    # Export decoder model
    decoder_path = f"{save_path}_decoder"
    paddle.jit.save(
        layer=wrapper,
        path=decoder_path,
        input_spec=[token_spec, audio_features_spec]
    )
    logger.info(f"Decoder model exported to {decoder_path}")
    return decoder_path


def export_whisper(model, save_path):
    """Export full Whisper model to static graph models."""
    dims = model.dims
    export_encoder(model, save_path, input_shape=(1, dims.n_mels, 3000))
    # The decoder's audio-feature input must match the encoder output:
    # (batch, n_audio_ctx, n_audio_state), e.g. (1, 1500, 512) for "base".
    export_decoder(
        model, save_path,
        input_shape_tokens=(1, dims.n_text_ctx),
        input_shape_features=(1, dims.n_audio_ctx, dims.n_audio_state))

    # Export model info
    model_info = {
        "n_mels": dims.n_mels,
        "n_vocab": dims.n_vocab,
        "n_audio_ctx": dims.n_audio_ctx,
        "n_audio_state": dims.n_audio_state,
        "n_audio_head": dims.n_audio_head,
        "n_audio_layer": dims.n_audio_layer,
        "n_text_ctx": dims.n_text_ctx,
        "n_text_state": dims.n_text_state,
        "n_text_head": dims.n_text_head,
        "n_text_layer": dims.n_text_layer
    }

    # Save model info
    with open(f"{save_path}_info.json", "w") as f:
        json.dump(model_info, f, indent=4)

    logger.info(f"Model info saved to {save_path}_info.json")


def main():
    parser = argparse.ArgumentParser(description="Export Whisper model to inference format")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save exported model")
    parser.add_argument("--model_size", type=str, default="base",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Model size")

    args = parser.parse_args()

    # Create model
    model_dims = MODEL_DIMENSIONS[args.model_size]
    model = Whisper(model_dims)

    # Load checkpoint
    state_dict = paddle.load(args.checkpoint)
    model.set_state_dict(state_dict)

    # Export model
    export_whisper(model, args.output_path)
    logger.info(f"Model exported to {args.output_path}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,323 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from pathlib import Path

import paddle

from paddlespeech.s2t.models.whisper.whisper import (MODEL_DIMENSIONS, DecodingOptions,
                                                     Whisper, log_mel_spectrogram, transcribe)
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_files_by_ext, str2bool

logger = Log(__name__).getlog()


def load_audio(file_path, sample_rate=16000):
    """Load an audio file as 16 kHz mono (helper; transcribe() loads audio
    itself through log_mel_spectrogram)."""
    try:
        import librosa
        audio, sr = librosa.load(file_path, sr=sample_rate, mono=True)
        return audio
    except Exception as e:
        logger.error(f"Error loading audio file {file_path}: {str(e)}")
        return None


def get_model_path(size, resource_path):
    """Get the pretrained weight path for a given model size."""
    return os.path.join(resource_path, "whisper", f"whisper-{size}.pdparams")


def load_whisper_model(model_size="base", checkpoint_path=None, resource_path=None):
    """Load a Whisper model from a fine-tuned checkpoint or pretrained weights."""
    model_dims = MODEL_DIMENSIONS[model_size]
    model = Whisper(model_dims)

    if checkpoint_path:
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")
        state_dict = paddle.load(checkpoint_path)
        model.set_state_dict(state_dict)
    elif resource_path:
        model_path = get_model_path(model_size, resource_path)
        if os.path.exists(model_path):
            logger.info(f"Loading pretrained model from: {model_path}")
            state_dict = paddle.load(model_path)
            model.set_state_dict(state_dict)
        else:
            logger.error(f"Pretrained model not found at {model_path}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
    else:
        logger.error("Either checkpoint_path or resource_path must be provided")
        raise ValueError("Either checkpoint_path or resource_path must be provided")

    return model


def transcribe_file(
        model,
        audio_file,
        resource_path,
        language=None,
        task="transcribe",
        temperature=0.0,
        beam_size=5,
        best_of=5,
        patience=1.0,
        length_penalty=1.0,
        suppress_tokens="-1",
        initial_prompt=None,
        condition_on_previous_text=True,
        without_timestamps=False,
        fp16=False,
        verbose=True,
):
    """Transcribe a single audio file."""
    # Check if file exists
    if not os.path.exists(audio_file):
        logger.error(f"Audio file not found: {audio_file}")
        return None

    # Load and process audio
    logger.info(f"Processing audio file: {audio_file}")
    try:
        mel = log_mel_spectrogram(audio_file, resource_path=resource_path)
    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        return None

    # Setup decoding options
    decode_options = DecodingOptions(
        language=language,
        task=task,
        temperature=temperature,
        beam_size=beam_size,
        best_of=best_of,
        patience=patience,
        length_penalty=length_penalty,
        suppress_tokens=suppress_tokens,
        prompt=initial_prompt,
        without_timestamps=without_timestamps,
        fp16=fp16,
    )

    # Run transcription
    logger.info("Running transcription...")
    start_time = time.time()

    model.eval()
    with paddle.no_grad():
        result = transcribe(
            model=model,
            mel=mel,
            resource_path=resource_path,
            verbose=verbose,
            condition_on_previous_text=condition_on_previous_text,
            **decode_options.__dict__,
        )

    elapsed = time.time() - start_time
    logger.info(f"Transcription completed in {elapsed:.2f} seconds")

    return result


def batch_transcribe(
        model,
        audio_dir,
        output_dir,
        resource_path,
        extensions=("wav", "mp3", "flac", "m4a", "ogg"),
        **transcribe_kwargs
):
    """Transcribe all audio files in a directory."""
    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get list of audio files
    audio_files = []
    for ext in extensions:
        audio_files.extend(get_files_by_ext(audio_dir, ext))

    if not audio_files:
        logger.error(f"No audio files found in {audio_dir} with extensions {extensions}")
        return

    logger.info(f"Found {len(audio_files)} audio files to process")

    # Process each file
    results = {}
    for audio_file in audio_files:
        base_name = os.path.basename(audio_file)
        file_name = os.path.splitext(base_name)[0]

        logger.info(f"Processing file: {base_name}")
        result = transcribe_file(model, audio_file, resource_path, **transcribe_kwargs)

        if result:
            # Save transcription
            output_file = output_dir / f"{file_name}.txt"
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(result["text"].strip())

            # Save detailed results
            results[base_name] = {
                "text": result["text"].strip(),
                "segments": result.get("segments", []),
                "language": result.get("language", ""),
            }

    # Save all results as JSON
    with open(output_dir / "all_transcripts.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    logger.info(f"All transcriptions saved to {output_dir}")
    return results


def main():
    parser = argparse.ArgumentParser(description="Whisper ASR Inference")
    parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription")
    parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files")
    parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    parser.add_argument("--resource_path", type=str, default="./resources",
                        help="Path to resources directory containing original models and assets")

    # Model options
    parser.add_argument("--use_original", type=str2bool, default=False,
                        help="Use original Whisper model instead of fine-tuned")
    parser.add_argument("--model_size", type=str, default="base",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Model size for original Whisper")

    # Decoding options
    parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr); omit to auto-detect")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="Task: transcribe or translate to English")
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
    parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
    parser.add_argument("--best_of", type=int, default=5, help="Number of candidates when sampling with temp > 0")
    parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="Exponential length penalty")
    parser.add_argument("--suppress_tokens", type=str, default="-1", help="Comma-separated list of token ids to suppress")
    parser.add_argument("--initial_prompt", type=str, default=None, help="Optional text to provide as prompt")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                        help="Whether to condition on previous text")
    parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps")
    parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16")
    parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display the text being decoded")

    args = parser.parse_args()

    # Validate arguments
    if not args.audio_file and not args.audio_dir:
        parser.error("Either --audio_file or --audio_dir must be specified")

    if args.use_original:
        # Use original Whisper model
        if not args.resource_path:
            parser.error("--resource_path must be specified when using original model")

        model = load_whisper_model(
            model_size=args.model_size,
            resource_path=args.resource_path
        )
    else:
        # Use fine-tuned model
        if not args.checkpoint:
            parser.error("--checkpoint must be specified when not using original model")

        # Determine model size from the checkpoint path, checking the more
        # specific names first so "large-v2"/"large-v3" are not matched as "large"
        model_size = "base"  # Default
        for size in ("tiny", "small", "medium", "large-v2", "large-v3", "large"):
            if size in args.checkpoint:
                model_size = size
                break

        model = load_whisper_model(
            model_size=model_size,
            checkpoint_path=args.checkpoint,
            resource_path=args.resource_path
        )

    # Prepare transcription keyword arguments
    transcribe_kwargs = {
        "language": args.language,
        "task": args.task,
        "temperature": args.temperature,
        "beam_size": args.beam_size,
        "best_of": args.best_of,
        "patience": args.patience,
        "length_penalty": args.length_penalty,
        "suppress_tokens": args.suppress_tokens,
        "initial_prompt": args.initial_prompt,
        "condition_on_previous_text": args.condition_on_previous_text,
        "without_timestamps": args.without_timestamps,
        "fp16": args.fp16,
        "verbose": args.verbose,
    }

    # Run transcription
    if args.audio_file:
        # Single file transcription
        result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
        if result:
            print("-" * 40)
            print("Transcription:")
            print(result["text"])
            print("-" * 40)

            # Save to output directory if specified
            if args.output_dir:
                output_dir = Path(args.output_dir)
                output_dir.mkdir(parents=True, exist_ok=True)

                file_name = os.path.splitext(os.path.basename(args.audio_file))[0]
                with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
                    f.write(result["text"].strip())

                with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)

                print(f"Results saved to {output_dir}")

    elif args.audio_dir:
        # Batch transcription
        batch_transcribe(
            model,
            args.audio_dir,
            args.output_dir,
            args.resource_path,
            **transcribe_kwargs
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,197 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import random
from pathlib import Path
from typing import Optional

import datasets
import soundfile
import tqdm
import yaml


def prepare_common_voice(language: str,
                         output_dir: str,
                         cache_dir: Optional[str] = None,
                         val_size: float = 0.03,
                         test_size: float = 0.03,
                         min_duration: float = 0.5,
                         max_duration: float = 30.0,
                         seed: int = 42):
    """
    Prepare the Mozilla Common Voice dataset for Whisper fine-tuning.

    Args:
        language: Language code (e.g., "en", "fr", "es")
        output_dir: Directory to save preprocessed data
        cache_dir: Cache directory for HuggingFace datasets
        val_size: Validation set size as a fraction of total data
        test_size: Test set size as a fraction of total data
        min_duration: Minimum audio duration in seconds
        max_duration: Maximum audio duration in seconds
        seed: Random seed for reproducibility
    """
    print(f"Preparing Common Voice dataset for language: {language}")

    # Create output directories
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "wavs").mkdir(exist_ok=True)

    # Load Common Voice dataset
    print("Loading Common Voice dataset from HuggingFace...")
    try:
        common_voice = datasets.load_dataset(
            "mozilla-foundation/common_voice_11_0",
            language,
            cache_dir=cache_dir,
            trust_remote_code=True
        )
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Make sure you have access to the Common Voice dataset on HuggingFace.")
        return

    # Filter and process data
    print("Processing dataset...")

    # Process a single example, filtering by duration
    def process_example(example):
        audio = example['audio']
        sample_rate = audio['sampling_rate']

        # Calculate duration
        duration = len(audio['array']) / sample_rate

        # Filter by duration
        if duration < min_duration or duration > max_duration:
            return None

        return {
            "audio": audio,
            "text": example['sentence'],
            "duration": duration,
        }

    # Process all splits. NOTE: the official train/validation/test splits are
    # merged here and re-split randomly below; decoded audio is kept in
    # memory, which is fine for small subsets but heavy for full languages.
    all_data = []
    for split in ['train', 'validation', 'test']:
        if split in common_voice:
            split_data = []
            for example in tqdm.tqdm(common_voice[split], desc=f"Processing {split} set"):
                processed = process_example(example)
                if processed:
                    split_data.append(processed)
            all_data.extend(split_data)

    # Shuffle and split data
    random.seed(seed)
    random.shuffle(all_data)

    total_size = len(all_data)
    val_count = max(1, int(total_size * val_size))
    test_count = max(1, int(total_size * test_size))
    train_count = total_size - val_count - test_count

    train_data = all_data[:train_count]
    val_data = all_data[train_count:train_count + val_count]
    test_data = all_data[train_count + val_count:]

    print(f"Dataset split - Train: {len(train_data)}, Dev: {len(val_data)}, Test: {len(test_data)}")

    # Save audio files and create manifest files
    def save_manifest(data, name):
        manifest = []
        for i, item in enumerate(tqdm.tqdm(data, desc=f"Saving {name} files")):
            # Generate filename
            filename = f"{name}_{i:08d}.wav"
            filepath = str(output_dir / "wavs" / filename)

            # Save audio file
            soundfile.write(
                filepath,
                item["audio"]["array"],
                item["audio"]["sampling_rate"]
            )

            # Add to manifest
            manifest_item = {
                "utt": f"{name}_{i:08d}",
                "audio": filepath,
                "text": item["text"],
                "duration": item["duration"]
            }
            manifest.append(manifest_item)

        # Write manifest file (JSON Lines: one object per line)
        with open(output_dir / f"{name}_manifest.json", "w", encoding="utf-8") as f:
            for item in manifest:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

    save_manifest(train_data, "train")
    save_manifest(val_data, "dev")
    save_manifest(test_data, "test")

    # Save dataset statistics
    stats = {
        "language": language,
        "total_examples": total_size,
        "train_examples": len(train_data),
        "dev_examples": len(val_data),
        "test_examples": len(test_data),
        "min_duration": min_duration,
        "max_duration": max_duration,
    }

    with open(output_dir / "stats.yaml", "w") as f:
        yaml.dump(stats, f)

    print(f"Dataset preparation complete. Files saved to {output_dir}")


def main():
    parser = argparse.ArgumentParser(description="Prepare Common Voice dataset for Whisper fine-tuning")
    parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)")
    parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data")
    parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets")
    parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction")
    parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction")
    parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds")
    parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()
    prepare_common_voice(
        language=args.language,
        output_dir=args.output_dir,
        cache_dir=args.cache_dir,
        val_size=args.val_size,
        test_size=args.test_size,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        seed=args.seed
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,510 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import yaml
|
||||
from paddle.io import DataLoader, Dataset
|
||||
from paddle.optimizer import AdamW
|
||||
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup
|
||||
|
||||
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
|
||||
from paddlespeech.s2t.models.whisper.whisper import (CHUNK_LENGTH, MODEL_DIMENSIONS, N_MELS, N_SAMPLES, Whisper,
|
||||
log_mel_spectrogram, pad_or_trim)
|
||||
from paddlespeech.s2t.utils.checkpoint import Checkpoint
|
||||
from paddlespeech.s2t.training.extensions.evaluator import StandardEvaluator
|
||||
from paddlespeech.s2t.training.extensions.reporter import ObsScope, Reporter
|
||||
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
|
||||
from paddlespeech.s2t.training.trainer import Trainer
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.s2t.utils.utility import get_rank, str2bool
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class WhisperDataset(Dataset):
|
||||
"""Dataset for Whisper fine-tuning"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest_path: str,
|
||||
tokenizer,
|
||||
target_language: str = "en",
|
||||
max_duration: float = 30.0,
|
||||
min_duration: float = 0.5,
|
||||
sample_rate: int = 16000,
|
||||
resource_path: str = '',
|
||||
):
|
||||
"""Initialize the dataset.
|
||||
|
||||
Args:
|
||||
manifest_path: Path to manifest file with audio paths and transcripts
|
||||
tokenizer: Whisper tokenizer
|
||||
target_language: Target language code
|
||||
max_duration: Maximum audio duration
|
||||
min_duration: Minimum audio duration
|
||||
sample_rate: Audio sample rate
|
||||
resource_path: Path to resources directory
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.target_language = target_language
|
||||
self.sample_rate = sample_rate
|
||||
self.resource_path = resource_path
|
||||
|
||||
# Load manifest
|
||||
with open(manifest_path, 'r', encoding='utf8') as f:
|
||||
manifest_lines = f.readlines()
|
||||
|
||||
self.data = []
|
||||
for line in manifest_lines:
|
||||
try:
|
||||
item = json.loads(line.strip())
|
||||
duration = item.get('duration', 0)
|
||||
if min_duration <= duration <= max_duration:
|
||||
self.data.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing line in manifest: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(self.data)} examples from {manifest_path}")
|
||||
|
||||
# Generate special tokens and language tokens
|
||||
self.special_tokens = {
|
||||
"sot": self.tokenizer.sot,
|
||||
"eot": self.tokenizer.eot,
|
||||
"translate": self.tokenizer.translate,
|
||||
"transcribe": self.tokenizer.transcribe,
|
||||
}
|
||||
|
||||
# Get language token
|
||||
self.language_token = self.tokenizer.language_tokens[self.target_language]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
# Load audio
|
||||
audio_path = item["audio"]
|
||||
try:
|
||||
# Process audio to mel spectrogram
|
||||
mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
|
||||
mel = pad_or_trim(mel, N_SAMPLES)
|
||||
|
||||
# Get text
|
||||
text = item["text"]
|
||||
|
||||
# Create prompt tokens
|
||||
prompt_tokens = [
|
||||
self.special_tokens["sot"],
|
||||
self.language_token,
|
||||
self.special_tokens["transcribe"],
|
||||
]
|
||||
|
||||
# Encode the text
|
||||
target_tokens = (
|
||||
prompt_tokens +
|
||||
self.tokenizer.encode(text) +
|
||||
[self.special_tokens["eot"]]
|
||||
)
|
||||
|
||||
return {
|
||||
"mel": mel,
|
||||
"target_tokens": np.array(target_tokens, dtype=np.int64),
|
||||
"text": text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing {audio_path}: {e}")
|
||||
# Return a dummy sample that will be filtered in collate_fn
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
"""Collate function for DataLoader"""
|
||||
# Filter None samples
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
# Get maximum sequence length in this batch
|
||||
max_target_len = max(len(sample["target_tokens"]) for sample in batch)
|
||||
|
||||
# Initialize tensors
|
||||
mel_specs = []
|
||||
token_ids = []
|
||||
labels = []
|
||||
|
||||
for sample in batch:
|
||||
target_tokens = sample["target_tokens"]
|
||||
target_len = len(target_tokens)
|
||||
|
||||
# Prepare inputs and labels for causal LM training
|
||||
# Input tokens are shifted right
|
||||
input_tokens = np.zeros(max_target_len, dtype=np.int64)
|
||||
input_tokens[:target_len-1] = target_tokens[:target_len-1] # Exclude EOS
|
||||
|
||||
# Labels are shifted left and padded with -100 (ignore index)
|
||||
label = np.full(max_target_len, -100, dtype=np.int64)
|
||||
label[:target_len-1] = target_tokens[1:target_len] # Start from first token after SOT
|
||||
|
||||
# Add to lists
|
||||
mel_specs.append(sample["mel"])
|
||||
token_ids.append(input_tokens)
|
||||
labels.append(label)
|
||||
|
||||
# Convert to tensors
|
||||
mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)
|
||||
token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64)
|
||||
labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64)
|
||||
|
||||
return {
|
||||
"mel": mel_specs,
|
||||
"tokens": token_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
class WhisperTrainer(Trainer):
|
||||
"""Trainer for Whisper fine-tuning"""
|
||||
|
||||
def __init__(self, config, args):
|
||||
"""Initialize the trainer.
|
||||
|
||||
Args:
|
||||
config: Training configuration
|
||||
args: Command-line arguments
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.args = args
|
||||
self.resource_path = args.resource_path
|
||||
|
||||
# Set random seed
|
||||
if args.seed is not None:
|
||||
paddle.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Initialize distributed training if needed
|
||||
if args.distributed:
|
||||
dist.init_parallel_env()
|
||||
|
||||
# Build model and optimizer
|
||||
self.model = self._build_model()
|
||||
self.optimizer = self._build_optimizer()
|
||||
|
||||
# Build tokenizer
|
||||
self.tokenizer = get_tokenizer(
|
||||
multilingual=self.model.is_multilingual,
|
||||
resource_path=self.resource_path,
|
||||
language=config['data']['target_language'],
|
||||
task="transcribe"
|
||||
)
|
||||
|
||||
# Initialize checkpoint class
|
||||
self._init_checkpoint()
|
||||
|
||||
def _build_model(self):
|
||||
"""Build Whisper model"""
|
||||
config = self.config
|
||||
model_size = config['model']['size']
|
||||
|
||||
# Load model dimensions
|
||||
model_dims = MODEL_DIMENSIONS[model_size]
|
||||
|
||||
# Create the model
|
||||
model = Whisper(model_dims)
|
||||
|
||||
# Load checkpoint if provided
|
||||
checkpoint_path = config['model']['checkpoint']
|
||||
if checkpoint_path:
|
||||
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
||||
state_dict = paddle.load(checkpoint_path)
|
||||
model.set_state_dict(state_dict)
|
||||
|
||||
# Freeze encoder if needed
|
||||
if config['model'].get('freeze_encoder', False):
|
||||
logger.info("Freezing encoder parameters")
|
||||
for param in model.encoder.parameters():
|
||||
param.stop_gradient = True
|
||||
|
||||
# Handle distributed training
|
||||
if self.args.distributed:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
return model
|
||||
|
||||
def _build_optimizer(self):
|
||||
"""Build optimizer and learning rate scheduler"""
|
||||
config = self.config
|
||||
model = self.model
|
||||
|
||||
# Set learning rate
|
||||
learning_rate = config['training']['learning_rate']
|
||||
|
||||
# Apply weight decay
|
||||
decay_params = [
|
||||
p.name for n, p in model.named_parameters()
|
||||
if not any(nd in n for nd in ["bias", "norm"])
|
||||
]
|
||||
|
||||
# Build scheduler
|
||||
scheduler_name = config['training'].get('scheduler', 'linear')
|
||||
num_training_steps = self.config['training'].get('max_steps', 100000)
|
||||
|
||||
# Calculate warmup steps
|
||||
warmup_ratio = config['training'].get('warmup_ratio', 0.1)
|
||||
warmup_steps = int(num_training_steps * warmup_ratio)
|
||||
|
||||
if scheduler_name == 'cosine':
|
||||
lr_scheduler = LinearWarmup(
|
||||
CosineAnnealingDecay(learning_rate, num_training_steps - warmup_steps),
|
||||
warmup_steps,
|
||||
0.0,
|
||||
learning_rate
|
||||
)
|
||||
else: # default to linear
|
||||
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
|
||||
paddle.optimizer.lr.PolynomialDecay(
|
||||
learning_rate=learning_rate,
|
||||
decay_steps=num_training_steps - warmup_steps,
|
||||
end_lr=0.0,
|
||||
power=1.0),
|
||||
warmup_steps,
|
||||
0.0,
|
||||
learning_rate
|
||||
)
|
||||
|
||||
# Create optimizer
|
||||
optimizer = AdamW(
|
||||
learning_rate=lr_scheduler,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
parameters=model.parameters(),
|
||||
weight_decay=config['training'].get('weight_decay', 0.01),
|
||||
grad_clip=nn.ClipGradByNorm(config['training'].get('max_grad_norm', 1.0)),
|
||||
apply_decay_param_fun=lambda x: x in decay_params
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
def _init_checkpoint(self):
|
||||
"""Initialize checkpoint for saving and loading"""
|
||||
config = self.config
|
||||
args = self.args
|
||||
|
||||
checkpoint_dir = Path(config['output']['checkpoint_dir'])
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.checkpoint = Checkpoint(
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
infos=dict(),
|
||||
visualizer=None,
|
||||
**{"epoch": 0}
|
||||
)
|
||||
|
||||
# Try to load from checkpoint if provided
|
||||
if args.checkpoint_path:
|
||||
self.checkpoint.load_parameters(args.checkpoint_path)
|
||||
|

    def train(self):
        """Run training"""
        config = self.config
        args = self.args

        # Create datasets
        train_dataset = WhisperDataset(
            manifest_path=config['data']['train_manifest'],
            tokenizer=self.tokenizer,
            target_language=config['data']['target_language'],
            max_duration=config['data']['max_duration'],
            min_duration=config['data']['min_duration'],
            resource_path=self.resource_path,
        )

        dev_dataset = WhisperDataset(
            manifest_path=config['data']['dev_manifest'],
            tokenizer=self.tokenizer,
            target_language=config['data']['target_language'],
            max_duration=config['data']['max_duration'],
            min_duration=config['data']['min_duration'],
            resource_path=self.resource_path,
        )

        # Create data loaders
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=config['training']['batch_size'],
            shuffle=True,
            num_workers=config['training'].get('num_workers', 4),
            collate_fn=WhisperDataset.collate_fn,
            drop_last=True,
        )

        dev_dataloader = DataLoader(
            dev_dataset,
            batch_size=config['eval']['eval_batch_size'],
            shuffle=False,
            num_workers=config['training'].get('num_workers', 4),
            collate_fn=WhisperDataset.collate_fn,
        )

        # Setup training
        max_epoch = config['training']['max_epoch']
        accum_grad = config['training'].get('accum_grad', 1)
        log_interval = config['training'].get('log_interval', 100)
        save_interval = config['output'].get('save_interval', 1)

        # Initialize reporter for logging
        reporter = Reporter()
        reporter.register(self.model, "model")

        # Initialize evaluator
        evaluator = StandardEvaluator(self.model)

        # Training loop
        for epoch in range(max_epoch):
            self.model.train()
            train_loss = 0
            num_batches = 0
            start_time = time.time()

            for batch_idx, batch in enumerate(train_dataloader):
                if batch is None:
                    continue

                mel = batch["mel"]
                tokens = batch["tokens"]
                labels = batch["labels"]

                # Forward pass
                audio_features = self.model.embed_audio(mel)
                logits = self.model.logits(tokens, audio_features)

                # Compute loss
                loss = F.cross_entropy(
                    logits.reshape([-1, logits.shape[-1]]),
                    labels.reshape([-1]),
                    ignore_index=-100
                )

                # Scale loss for gradient accumulation
                if accum_grad > 1:
                    loss = loss / accum_grad

                # Backward pass
                loss.backward()

                # Update parameters every accum_grad steps
                if (batch_idx + 1) % accum_grad == 0:
                    self.optimizer.step()
                    if getattr(self, 'lr_scheduler', None) is not None:
                        self.lr_scheduler.step()
                    self.optimizer.clear_grad()
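
                # Note: with gradient accumulation the effective batch size is
                # batch_size * accum_grad; parameters and the LR schedule
                # advance once per accum_grad micro-batches.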

                # Logging
                train_loss += loss.item() * accum_grad
                num_batches += 1

                if batch_idx % log_interval == 0 and get_rank() == 0:
                    elapsed_time = time.time() - start_time
                    logger.info(f"Epoch {epoch}/{max_epoch} | Batch {batch_idx}/{len(train_dataloader)} | "
                                f"Loss: {loss.item()*accum_grad:.4f} | {elapsed_time:.2f}s elapsed")

            # End of epoch
            avg_train_loss = train_loss / num_batches if num_batches > 0 else float('inf')
            logger.info(f"Epoch {epoch}/{max_epoch} | Average Train Loss: {avg_train_loss:.4f}")

            # Evaluation
            self.model.eval()
            dev_losses = []

            with paddle.no_grad():
                for batch in dev_dataloader:
                    if batch is None:
                        continue

                    mel = batch["mel"]
                    tokens = batch["tokens"]
                    labels = batch["labels"]

                    # Forward pass
                    audio_features = self.model.embed_audio(mel)
                    logits = self.model.logits(tokens, audio_features)

                    # Compute loss
                    loss = F.cross_entropy(
                        logits.reshape([-1, logits.shape[-1]]),
                        labels.reshape([-1]),
                        ignore_index=-100
                    )

                    dev_losses.append(loss.item())

            avg_dev_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf')
            logger.info(f"Epoch {epoch}/{max_epoch} | Validation Loss: {avg_dev_loss:.4f}")

            # Update checkpoint info
            self.checkpoint.infos["epoch"] = epoch
            self.checkpoint.infos["train_loss"] = avg_train_loss
            self.checkpoint.infos["dev_loss"] = avg_dev_loss

            # Save checkpoint
            if epoch % save_interval == 0 and get_rank() == 0:
                self.checkpoint.save_parameters(tag=f"epoch_{epoch}")

        # Save final model
        if get_rank() == 0:
            self.checkpoint.save_parameters(tag="final")
            logger.info(f"Training completed. Final model saved at {self.checkpoint.checkpoint_dir}")


def main():
    parser = argparse.ArgumentParser(description="Train Whisper model")
    parser.add_argument("--config", type=str, required=True, help="Path to configuration file")
    parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")
    parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
    parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training")

    args = parser.parse_args()

    # Load configuration
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Set device
    paddle.set_device(args.device)

    trainer = WhisperTrainer(config, args)
    trainer.train()
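

# Example invocation (paths illustrative):
#   python train.py --config conf/whisper_base.yaml --resource_path ./resources --device gpu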


if __name__ == "__main__":
    main()
@ -0,0 +1,295 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
from matplotlib.ticker import MaxNLocator

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


def plot_waveform(audio_path, output_dir=None, show=True):
    """Plot waveform of audio file."""
    try:
        audio, sr = sf.read(audio_path)
        if audio.ndim > 1:
            audio = audio[:, 0]  # Take first channel if stereo

        duration = len(audio) / sr
        time_axis = np.linspace(0, duration, len(audio))

        plt.figure(figsize=(12, 4))
        plt.plot(time_axis, audio, color='#1f77b4')
        plt.title(f"Waveform: {os.path.basename(audio_path)}")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Amplitude")
        plt.grid(alpha=0.3)

        if output_dir:
            output_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(audio_path))[0]}_waveform.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Waveform plot saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting waveform: {e}")
        return False


def plot_transcription_comparison(reference, hypothesis, audio_path=None, output_dir=None, show=True):
    """Plot comparison between reference and hypothesis transcriptions."""
    try:
        fig, ax = plt.subplots(figsize=(12, 6))

        # Title with audio file name if provided
        title = "Reference vs. Hypothesis"
        if audio_path:
            title += f": {os.path.basename(audio_path)}"
        plt.title(title)

        # Create comparison table
        rows = []
        rows.append(["Reference", reference])
        rows.append(["Hypothesis", hypothesis])

        # Calculate word-level differences
        ref_words = reference.strip().split()
        hyp_words = hypothesis.strip().split()

        from difflib import SequenceMatcher
        matcher = SequenceMatcher(None, ref_words, hyp_words)

        # Mark differing words (replacements, deletions, insertions) with **
        ref_colored = []
        hyp_colored = []

        for op, i1, i2, j1, j2 in matcher.get_opcodes():
            if op == 'equal':
                ref_colored.extend(ref_words[i1:i2])
                hyp_colored.extend(hyp_words[j1:j2])
            elif op == 'replace':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])
            elif op == 'delete':
                ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]])
            elif op == 'insert':
                hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]])

        rows.append(["Ref (Diff)", " ".join(ref_colored)])
        rows.append(["Hyp (Diff)", " ".join(hyp_colored)])

        # Rough position-wise word error count (not a full alignment-based WER)
        word_error = sum(1 for i, j in zip(ref_words, hyp_words) if i != j)
        if len(ref_words) > 0:
            word_error_rate = word_error / len(ref_words)
        else:
            word_error_rate = 0

        rows.append(["Word Error", f"{word_error}/{len(ref_words)} ({word_error_rate:.2%})"])

        # Create table
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=rows, colWidths=[0.15, 0.85], loc='center', cellLoc='left')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)

        # Highlight difference rows with light blue background
        for i in [2, 3]:
            for j in range(2):
                cell = table._cells[(i, j)]
                cell.set_facecolor('#E6F3FF')

        if output_dir:
            base_name = os.path.splitext(os.path.basename(audio_path))[0] if audio_path else "comparison"
            output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Comparison plot saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting transcription comparison: {e}")
        return False


def plot_evaluation_metrics(results_file, output_dir=None, show=True):
    """Plot evaluation metrics from results file."""
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)

        metrics = results['metrics']
        details = results['details']

        # Extract duration information
        durations = [item['duration'] for item in details if 'duration' in item]
        if not durations:
            durations = [0] * len(details)  # Default if no durations available

        # Calculate per-sample WER
        from paddlespeech.s2t.utils.error_rate import word_errors
        wers = []
        for item in details:
            ref = item['reference']
            hyp = item['hypothesis']
            word_error_count, word_count = word_errors(ref, hyp)
            wer = word_error_count / max(1, word_count)
            wers.append(wer * 100)  # Convert to percentage

        # Create plots
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Overall metrics
        ax = axes[0, 0]
        metric_names = ['WER', 'CER']
        metric_values = [metrics['wer'] * 100, metrics['cer'] * 100]  # Convert to percentage
        ax.bar(metric_names, metric_values, color=['#1f77b4', '#ff7f0e'])
        ax.set_title('Overall Error Metrics')
        ax.set_ylabel('Error Rate (%)')
        ax.grid(axis='y', alpha=0.3)
        for i, v in enumerate(metric_values):
            ax.text(i, v + 0.5, f"{v:.2f}%", ha='center')

        # WER vs Duration scatter plot
        ax = axes[0, 1]
        ax.scatter(durations, wers, alpha=0.6)
        ax.set_title('WER vs Audio Duration')
        ax.set_xlabel('Duration (seconds)')
        ax.set_ylabel('WER (%)')
        ax.grid(alpha=0.3)

        # WER distribution histogram
        ax = axes[1, 0]
        ax.hist(wers, bins=20, color='#2ca02c', alpha=0.7)
        ax.set_title('WER Distribution')
        ax.set_xlabel('WER (%)')
        ax.set_ylabel('Number of Samples')
        ax.grid(alpha=0.3)

        # Sample count by WER range
        ax = axes[1, 1]
        wer_ranges = [0, 5, 10, 20, 30, 50, 100]
        wer_bins = pd.cut(wers, wer_ranges, right=False)
        wer_counts = pd.Series(wer_bins).value_counts().sort_index()

        wer_labels = [f"{wer_ranges[i]}-{wer_ranges[i+1]}%" for i in range(len(wer_ranges)-1)]
        ax.bar(wer_labels, wer_counts, color='#d62728')
        ax.set_title('Sample Count by WER Range')
        ax.set_xlabel('WER Range')
        ax.set_ylabel('Number of Samples')
        ax.tick_params(axis='x', rotation=45)
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        ax.grid(axis='y', alpha=0.3)

        plt.tight_layout()

        if output_dir:
            output_path = os.path.join(output_dir, "evaluation_metrics.png")
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            logger.info(f"Metrics plots saved to {output_path}")

        if show:
            plt.show()
        else:
            plt.close()

        return True
    except Exception as e:
        logger.error(f"Error plotting evaluation metrics: {e}")
        return False
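
# The results JSON consumed above is expected to look like this (a sketch
# inferred from the fields accessed in this file; values illustrative):
#
#   {
#     "metrics": {"wer": 0.12, "cer": 0.05},
#     "details": [
#       {"audio": "path/to/a.wav", "reference": "...", "hypothesis": "...",
#        "duration": 3.4},
#       ...
#     ]
#   }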


def visualize_results(results_file, output_dir=None, audio_dir=None, num_samples=5, show=True):
    """Visualize evaluation results."""
    try:
        # Create output directory if needed
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Load results
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)

        # Plot overall metrics
        plot_evaluation_metrics(results_file, output_dir, show)

        # Plot individual examples
        if num_samples > 0:
            details = results['details']
            samples = details[:num_samples]

            for i, sample in enumerate(samples):
                audio_path = sample['audio']
                reference = sample['reference']
                hypothesis = sample['hypothesis']

                # If audio directory is provided, use audio files from there
                if audio_dir:
                    base_name = os.path.basename(audio_path)
                    audio_path = os.path.join(audio_dir, base_name)

                # Plot waveform if audio file exists
                if os.path.exists(audio_path):
                    plot_waveform(audio_path, output_dir, show)

                # Plot transcription comparison
                plot_transcription_comparison(reference, hypothesis, audio_path, output_dir, show)

        return True
    except Exception as e:
        logger.error(f"Error visualizing results: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Visualize Whisper evaluation results")
    parser.add_argument("--results_file", type=str, required=True, help="Path to evaluation results JSON file")
    parser.add_argument("--output_dir", type=str, default="./visualizations", help="Directory to save visualization outputs")
    parser.add_argument("--audio_dir", type=str, default=None, help="Directory containing audio files (optional)")
    parser.add_argument("--num_samples", type=int, default=5, help="Number of individual samples to visualize")
    parser.add_argument("--show", action="store_true", help="Show plots interactively")

    args = parser.parse_args()

    visualize_results(
        results_file=args.results_file,
        output_dir=args.output_dir,
        audio_dir=args.audio_dir,
        num_samples=args.num_samples,
        show=args.show
    )


if __name__ == "__main__":
    main()
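
# Example (script and file names illustrative):
#   python visualize.py --results_file ./results/results.json --output_dir ./visualizations --num_samples 3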
@ -0,0 +1,240 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys
from pathlib import Path

import paddle
import yaml

from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import str2bool

logger = Log(__name__).getlog()


def prepare_parser():
    """Prepare argument parser."""
    parser = argparse.ArgumentParser(
        description="Whisper CLI for fine-tuning, inference, and evaluation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Data preparation command
    data_parser = subparsers.add_parser("prepare", help="Prepare data for fine-tuning")
    data_parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)")
    data_parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data")
    data_parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets")
    data_parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction")
    data_parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction")
    data_parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds")
    data_parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds")
    data_parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Fine-tuning command
    train_parser = subparsers.add_parser("train", help="Fine-tune Whisper model")
    train_parser.add_argument("--config", type=str, required=True, help="Path to configuration file")
    train_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")
    train_parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
    train_parser.add_argument("--seed", type=int, default=42, help="Random seed")
    train_parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
    train_parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Inference with Whisper model")
    infer_parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription")
    infer_parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files")
    infer_parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
    infer_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    infer_parser.add_argument("--resource_path", type=str, default="./resources",
                              help="Path to resources directory containing original models and assets")

    # Model options for inference
    infer_parser.add_argument("--use_original", type=str2bool, default=False,
                              help="Use original Whisper model instead of fine-tuned")
    infer_parser.add_argument("--model_size", type=str, default="base",
                              choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                              help="Model size for original Whisper")

    # Decoding options for inference
    infer_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)")
    infer_parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                              help="Task: transcribe or translate to English")
    infer_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
    infer_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
    infer_parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps")

    # Evaluation command
    eval_parser = subparsers.add_parser("evaluate", help="Evaluate Whisper model")
    eval_parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file")
    eval_parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results")
    eval_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    eval_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory")

    # Model options for evaluation
    eval_parser.add_argument("--model_size", type=str, default="base",
                             choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                             help="Model size for original Whisper")

    # Decoding options for evaluation
    eval_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)")
    eval_parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                             help="Task: transcribe or translate to English")
    eval_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
    eval_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
    eval_parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps")
    eval_parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")

    return parser
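

# Example invocations (checkpoint and data paths illustrative; the flags are
# defined in prepare_parser above):
#   python whisper_cli.py infer --audio_file sample.wav \
#       --checkpoint ./exp/checkpoints/epoch_10 --resource_path ./resources --language en
#   python whisper_cli.py evaluate --manifest ./data/test.jsonl \
#       --checkpoint ./exp/checkpoints/final --output_dir ./results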


def main():
    parser = prepare_parser()
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Execute the appropriate command
    if args.command == "prepare":
        from prepare_data import prepare_common_voice
        prepare_common_voice(
            language=args.language,
            output_dir=args.output_dir,
            cache_dir=args.cache_dir,
            val_size=args.val_size,
            test_size=args.test_size,
            min_duration=args.min_duration,
            max_duration=args.max_duration,
            seed=args.seed
        )

    elif args.command == "train":
        import train
        if args.distributed:
            import paddle.distributed as dist
            dist.init_parallel_env()

        # Load configuration
        with open(args.config) as f:
            config = yaml.safe_load(f)

        # Set device
        paddle.set_device(args.device)

        trainer = train.WhisperTrainer(config, args)
        trainer.train()

    elif args.command == "infer":
        from infer import load_whisper_model, transcribe_file, batch_transcribe

        # Validate arguments
        if not args.audio_file and not args.audio_dir:
            logger.error("Either --audio_file or --audio_dir must be specified")
            sys.exit(1)

        if args.use_original:
            # Use original Whisper model
            if not args.resource_path:
                logger.error("--resource_path must be specified when using original model")
                sys.exit(1)

            model = load_whisper_model(
                model_size=args.model_size,
                resource_path=args.resource_path
            )
        else:
            # Use fine-tuned model
            if not args.checkpoint:
                logger.error("--checkpoint must be specified when not using original model")
                sys.exit(1)

            model = load_whisper_model(
                model_size=args.model_size,
                checkpoint_path=args.checkpoint,
                resource_path=args.resource_path
            )

        # Prepare transcription keyword arguments
        transcribe_kwargs = {
            "language": args.language,
            "task": args.task,
            "temperature": args.temperature,
            "beam_size": args.beam_size,
            "without_timestamps": args.without_timestamps,
        }

        # Run transcription
        if args.audio_file:
            result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
            if result:
                print("-" * 40)
                print("Transcription:")
                print(result["text"])
                print("-" * 40)

                # Save to output directory if specified
                if args.output_dir:
                    output_dir = Path(args.output_dir)
                    output_dir.mkdir(parents=True, exist_ok=True)

                    file_name = os.path.splitext(os.path.basename(args.audio_file))[0]
                    with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
                        f.write(result["text"].strip())

                    import json
                    with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)

                    print(f"Results saved to {output_dir}")

        elif args.audio_dir:
            batch_transcribe(
                model,
                args.audio_dir,
                args.output_dir,
                args.resource_path,
                **transcribe_kwargs
            )

    elif args.command == "evaluate":
        from evaluate import load_model, evaluate_manifest

        model = load_model(
            model_size=args.model_size,
            checkpoint_path=args.checkpoint,
            resource_path=args.resource_path
        )

        evaluate_manifest(
            model=model,
            manifest_path=args.manifest,
            resource_path=args.resource_path,
            language=args.language,
            task=args.task,
            temperature=args.temperature,
            beam_size=args.beam_size,
            without_timestamps=args.without_timestamps,
            output_dir=args.output_dir,
            max_samples=args.max_samples
        )


if __name__ == "__main__":
    main()
@ -0,0 +1,41 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Whisper fine-tuning module.

This module provides utilities and classes for fine-tuning Whisper models
on custom datasets, following the paper "Robust Speech Recognition via
Large-Scale Weak Supervision" and Hugging Face's fine-tuning approach.

References:
    - Radford, A. et al. (2022). Robust Speech Recognition via Large-Scale
      Weak Supervision. https://arxiv.org/abs/2212.04356
    - https://github.com/openai/whisper
    - https://huggingface.co/blog/fine-tune-whisper
"""

from paddlespeech.s2t.models.whisper.fine_tune.dataset import (
    WhisperDataset,
    WhisperInferenceDataset
)
from paddlespeech.s2t.models.whisper.fine_tune.trainer import (
    WhisperTrainer
)

__all__ = [
    'WhisperDataset',
    'WhisperInferenceDataset',
    'WhisperTrainer',
]
@ -0,0 +1,360 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import paddle
from paddle.io import Dataset, DataLoader

from paddlespeech.s2t.models.whisper.whisper import (N_MELS, N_SAMPLES, log_mel_spectrogram, pad_or_trim)
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


class WhisperDataset(Dataset):
    """Dataset for Whisper fine-tuning"""

    def __init__(
            self,
            manifest_path: str,
            tokenizer,
            target_language: str = "en",
            task: str = "transcribe",
            max_duration: float = 30.0,
            min_duration: float = 0.5,
            sample_rate: int = 16000,
            resource_path: str = '',
            pad_to_max_length: bool = False,
    ):
        """Initialize the dataset.

        Args:
            manifest_path: Path to manifest file with audio paths and transcripts
            tokenizer: Whisper tokenizer
            target_language: Target language code
            task: Task type, either 'transcribe' or 'translate'
            max_duration: Maximum audio duration
            min_duration: Minimum audio duration
            sample_rate: Audio sample rate
            resource_path: Path to resources directory
            pad_to_max_length: Whether to pad all sequences to maximum length in batch
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.target_language = target_language
        self.task = task
        self.sample_rate = sample_rate
        self.resource_path = resource_path
        self.pad_to_max_length = pad_to_max_length

        # Load manifest
        with open(manifest_path, 'r', encoding='utf8') as f:
            manifest_lines = f.readlines()

        self.data = []
        for line in manifest_lines:
            try:
                item = json.loads(line.strip())
                duration = item.get('duration', 0)
                if min_duration <= duration <= max_duration:
                    self.data.append(item)
            except Exception as e:
                logger.warning(f"Error parsing line in manifest: {e}")

        logger.info(f"Loaded {len(self.data)} examples from {manifest_path}")

        # Generate special tokens and language tokens
        self.special_tokens = {
            "sot": self.tokenizer.sot,
            "eot": self.tokenizer.eot,
            "translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None,
            "transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None,
            "no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None,
        }

        # Get language token
        self.language_token = self.tokenizer.language_tokens.get(self.target_language) if hasattr(self.tokenizer, "language_tokens") else None
        if not self.language_token and hasattr(self.tokenizer, "language_token"):
            self.language_token = self.tokenizer.language_token(self.target_language)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load audio
        audio_path = item["audio"]
        try:
            # Process audio to mel spectrogram
            mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
            # Pad/trim along the time axis to 30 s worth of frames (hop length
            # 160), the fixed input length the Whisper encoder expects
            mel = pad_or_trim(mel, N_SAMPLES // 160)

            # Get text
            text = item["text"]

            # Create prompt tokens
            prompt_tokens = [self.special_tokens["sot"]]

            # Add language token if available
            if self.language_token is not None:
                prompt_tokens.append(self.language_token)

            # Add task token if available
            task_token = self.special_tokens.get(self.task)
            if task_token is not None:
                prompt_tokens.append(task_token)

            # Add no_timestamps token if available
            if self.special_tokens["no_timestamps"] is not None:
                prompt_tokens.append(self.special_tokens["no_timestamps"])

            # Encode the text
            target_tokens = (
                prompt_tokens +
                self.tokenizer.encode(text) +
                [self.special_tokens["eot"]]
            )

            return {
                "mel": mel,
                "target_tokens": np.array(target_tokens, dtype=np.int64),
                "text": text,
                "audio_path": audio_path
            }

        except Exception as e:
            logger.warning(f"Error processing {audio_path}: {e}")
            # Return a dummy sample that will be filtered in collate_fn
            return None

    @staticmethod
    def collate_fn(batch, pad_idx=-100):
        """Collate function for DataLoader"""
        # Filter None samples
        batch = [sample for sample in batch if sample is not None]
        if not batch:
            return None

        # Get maximum sequence length in this batch
        max_target_len = max(len(sample["target_tokens"]) for sample in batch)

        # Initialize tensors
        mel_specs = []
        token_ids = []
        labels = []

        # Also collect metadata for debugging/logging
        texts = []
        audio_paths = []

        for sample in batch:
            target_tokens = sample["target_tokens"]
            target_len = len(target_tokens)

            # Prepare inputs and labels for causal LM training
            # Input tokens are shifted right
            input_tokens = np.zeros(max_target_len, dtype=np.int64)
            input_tokens[:target_len-1] = target_tokens[:target_len-1]  # Exclude EOS

            # Labels are shifted left and padded with pad_idx (ignore index)
            label = np.full(max_target_len, pad_idx, dtype=np.int64)
            label[:target_len-1] = target_tokens[1:target_len]  # Start from first token after SOT

            # Add to lists
            mel_specs.append(sample["mel"])
            token_ids.append(input_tokens)
            labels.append(label)

            # Collect metadata
            texts.append(sample["text"])
            audio_paths.append(sample["audio_path"])

        # Convert to tensors
        mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)
        token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64)
        labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64)

        return {
            "mel": mel_specs,
            "tokens": token_ids,
            "labels": labels,
            "texts": texts,
            "audio_paths": audio_paths,
        }
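
    # Shapes produced by collate_fn (B = batch size, T = max_target_len):
    #   mel:    [B, n_mels, n_frames] float32
    #   tokens: [B, T] int64 -- decoder input, right-padded with 0
    #   labels: [B, T] int64 -- next-token targets, padded with pad_idx (-100),
    #     so padded positions are ignored by cross_entropy(ignore_index=-100)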

    def create_dataloader(self,
                          batch_size=16,
                          num_workers=4,
                          shuffle=True,
                          drop_last=False):
        """Create a dataloader from this dataset"""
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            drop_last=drop_last
        )
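

# A minimal usage sketch (manifest path and tokenizer construction are
# illustrative; see get_tokenizer usage in the trainer module):
#
#   dataset = WhisperDataset("data/train.jsonl", tokenizer,
#                            target_language="en", resource_path="./resources")
#   loader = dataset.create_dataloader(batch_size=16)
#   batch = next(iter(loader))  # dict with "mel", "tokens", "labels", ...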


class WhisperInferenceDataset(Dataset):
    """Dataset for Whisper inference with batching support"""

    def __init__(
            self,
            audio_paths: Union[str, List[str]],
            tokenizer=None,
            language: Optional[str] = None,
            task: str = "transcribe",
            sample_rate: int = 16000,
            resource_path: str = '',
    ):
        """Initialize the inference dataset.

        Args:
            audio_paths: Path to audio file or directory, or list of audio file paths
            tokenizer: Whisper tokenizer (optional, only needed for preparing decoder input)
            language: Language code (e.g., 'en', 'fr')
            task: Task type, either 'transcribe' or 'translate'
            sample_rate: Audio sample rate
            resource_path: Path to resources directory
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.language = language
        self.task = task
        self.sample_rate = sample_rate
        self.resource_path = resource_path

        # Process audio paths
        if isinstance(audio_paths, str):
            # Single file
            if os.path.isfile(audio_paths):
                self.audio_paths = [audio_paths]
            # Directory
            elif os.path.isdir(audio_paths):
                self.audio_paths = [
                    os.path.join(audio_paths, f) for f in os.listdir(audio_paths)
                    if f.endswith(('.wav', '.mp3', '.flac', '.ogg'))
                ]
            else:
                raise ValueError(f"Path not found: {audio_paths}")
        else:
            # List of files
            self.audio_paths = audio_paths

        logger.info(f"Loaded {len(self.audio_paths)} audio files for inference")

        # Generate special tokens and language tokens if tokenizer provided
        self.special_tokens = None
        self.language_token = None

        if self.tokenizer:
            self.special_tokens = {
                "sot": self.tokenizer.sot,
                "eot": self.tokenizer.eot,
                "translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None,
                "transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None,
                "no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None,
            }

            # Get language token
            if self.language and hasattr(self.tokenizer, "language_tokens"):
                self.language_token = self.tokenizer.language_tokens.get(self.language)
            elif self.language and hasattr(self.tokenizer, "language_token"):
                self.language_token = self.tokenizer.language_token(self.language)

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]

        # Compute mel spectrogram
        try:
            mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path)
        except Exception as e:
            logger.error(f"Error processing audio file {audio_path}: {e}")
            # Return zero tensor with correct shape in case of error
            mel = paddle.zeros((N_MELS, N_SAMPLES // 160))

        return {
            "mel": mel.numpy(),
            "audio_path": audio_path,
        }

    @staticmethod
    def collate_fn(batch):
        """Collate function for inference DataLoader"""
        # Extract mel spectrograms and audio paths
        mel_specs = []
        audio_paths = []

        for sample in batch:
            mel_specs.append(sample["mel"])
            audio_paths.append(sample["audio_path"])

        # Convert to tensors
        mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32)

        return {
            "mel": mel_specs,
            "audio_paths": audio_paths,
        }

    def create_dataloader(self,
                          batch_size=1,
                          num_workers=4):
        """Create a dataloader from this inference dataset"""
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            drop_last=False
        )

    def prepare_decoder_input(self):
        """Prepare decoder input tokens for initial prompt

        Only used when tokenizer is provided.
        """
        if not self.tokenizer:
            logger.warning("Cannot prepare decoder input without tokenizer")
            return None

        # Create initial tokens - similar to training but without labels
        tokens = [self.special_tokens["sot"]]

        if self.language_token:
            tokens.append(self.language_token)

        if self.task == "translate":
            tokens.append(self.special_tokens["translate"])
        elif self.task == "transcribe":
            tokens.append(self.special_tokens["transcribe"])

        if self.special_tokens["no_timestamps"]:
            tokens.append(self.special_tokens["no_timestamps"])

        return paddle.to_tensor([tokens], dtype=paddle.int64)
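
# For a multilingual tokenizer with language='en' and task='transcribe',
# prepare_decoder_input() yields ids for the sequence
# [<|startoftranscript|>, <|en|>, <|transcribe|>, <|notimestamps|>]
# (token spellings here follow the usual Whisper tokenizer conventions).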
@ -0,0 +1,380 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed as dist
from paddle.io import DataLoader
from paddle.optimizer import AdamW
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup

from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper
from paddlespeech.s2t.training.reporter import ObsScope, Reporter
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import get_rank

logger = Log(__name__).getlog()


class WhisperTrainer:
    """Trainer for fine-tuning Whisper models."""

    def __init__(
            self,
            config: Dict,
            model: Optional[Whisper] = None,
            optimizer: Optional[paddle.optimizer.Optimizer] = None,
            checkpoint_dir: Optional[Union[str, Path]] = None,
            resource_path: Optional[str] = None,
            log_interval: int = 10,
            save_interval: int = 1,
            rank: int = 0,
            world_size: int = 1,
    ):
        """Initialize the trainer.

        Args:
            config: Training configuration dictionary
            model: Whisper model instance, or None to build one from config
            optimizer: Optimizer instance, or None to build one from config
            checkpoint_dir: Directory to save checkpoints
            resource_path: Path to resources directory containing whisper assets
            log_interval: Steps between logging
            save_interval: Epochs between checkpoint saving
            rank: Local rank for distributed training
            world_size: Total number of processes for distributed training
        """
        self.config = config
        self.resource_path = resource_path
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.rank = rank
        self.world_size = world_size

        # Initialize model if not provided
        self.model = model if model is not None else self._build_model()

        # Initialize tokenizer
        self.tokenizer = self._build_tokenizer()

        # Initialize optimizer if not provided
        self.optimizer = optimizer if optimizer is not None else self._build_optimizer()

        # Initialize checkpoint
        self.checkpoint = self._init_checkpoint(checkpoint_dir)

        # Initialize reporter and timer
        self.reporter = Reporter()
        self.timer = Timer()

    def _build_model(self) -> Whisper:
        """Build Whisper model from config."""
        model_config = self.config['model']
        model_size = model_config['size']

        # Load model dimensions
        model_dims = MODEL_DIMENSIONS[model_size]

        # Create the model
        model = Whisper(model_dims)

        # Load checkpoint if provided
        checkpoint_path = model_config.get('checkpoint')
        if checkpoint_path:
            logger.info(f"Loading model weights from {checkpoint_path}")
            state_dict = paddle.load(checkpoint_path)
            model.set_state_dict(state_dict)

        # Freeze encoder if needed
        if model_config.get('freeze_encoder', False):
            logger.info("Freezing encoder parameters")
            for param in model.encoder.parameters():
                param.stop_gradient = True

        # Handle distributed training
        if self.world_size > 1:
            logger.info(f"Initializing distributed model with {self.world_size} processes")
            model = paddle.DataParallel(model)

        return model

    def _build_tokenizer(self):
        """Build Whisper tokenizer."""
        target_language = self.config['data'].get('target_language', 'en')
        task = self.config['data'].get('task', 'transcribe')

        return get_tokenizer(
            multilingual=self.model.is_multilingual,
            resource_path=self.resource_path,
            language=target_language,
            task=task
        )

    def _build_optimizer(self) -> paddle.optimizer.Optimizer:
        """Build optimizer and learning rate scheduler."""
        training_config = self.config['training']

        # Set learning rate
        learning_rate = training_config['learning_rate']

        # Parameters that receive weight decay (biases and norm layers excluded)
        decay_params = [
            p.name for n, p in self.model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]

        # Build scheduler
        scheduler_name = training_config.get('scheduler', 'linear')
        max_steps = training_config.get('max_steps', 100000)

        # Calculate warmup steps
        warmup_ratio = training_config.get('warmup_ratio', 0.1)
        warmup_steps = int(max_steps * warmup_ratio)

        if scheduler_name == 'cosine':
            lr_scheduler = LinearWarmup(
                CosineAnnealingDecay(learning_rate, max_steps - warmup_steps),
                warmup_steps,
                0.0,
                learning_rate
            )
        else:  # default to linear
            lr_scheduler = paddle.optimizer.lr.LinearWarmup(
                paddle.optimizer.lr.PolynomialDecay(
                    learning_rate=learning_rate,
                    decay_steps=max_steps - warmup_steps,
                    end_lr=0.0,
                    power=1.0),
                warmup_steps,
                0.0,
                learning_rate
            )

        # Keep a handle so the training loop can step the schedule; Paddle
        # LR schedulers are not advanced automatically by optimizer.step()
        self.lr_scheduler = lr_scheduler

        # Create optimizer
        weight_decay = training_config.get('weight_decay', 0.01)
        max_grad_norm = training_config.get('max_grad_norm', 1.0)

        optimizer = AdamW(
            learning_rate=lr_scheduler,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-8,
            parameters=self.model.parameters(),
            weight_decay=weight_decay,
            grad_clip=nn.ClipGradByNorm(max_grad_norm),
            apply_decay_param_fun=lambda x: x in decay_params
        )

        return optimizer

    def _init_checkpoint(self, checkpoint_dir=None) -> Checkpoint:
        """Initialize checkpoint for saving and loading."""
        if checkpoint_dir is None:
            checkpoint_dir = self.config['output'].get('checkpoint_dir', './exp/whisper_fine_tune')

        checkpoint_dir = Path(checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        return Checkpoint(
            checkpoint_dir=checkpoint_dir,
            model=self.model,
            optimizer=self.optimizer,
            infos=dict(),
            visualizer=None,
            **{"epoch": 0}
        )

    def train(self, train_loader: DataLoader, dev_loader: Optional[DataLoader] = None, num_epochs: int = 10):
        """Run training loop.

        Args:
            train_loader: DataLoader for training data
            dev_loader: DataLoader for validation data
            num_epochs: Number of epochs to train
        """
        self.reporter.register(self.model, "model")

        # Get training configuration
        training_config = self.config['training']
        max_epoch = num_epochs or training_config.get('max_epoch', 10)
        accum_grad = training_config.get('accum_grad', 1)

        # Start training
        logger.info(f"Starting training for {max_epoch} epochs")
        self.timer.start()

        # Resume from checkpoint if epoch > 0
        start_epoch = self.checkpoint.infos.get("epoch", 0)

        for epoch in range(start_epoch, max_epoch):
            self._train_epoch(train_loader, epoch, accum_grad)

            # Validation
            if dev_loader is not None:
                dev_loss = self._eval_epoch(dev_loader, epoch)
                self.checkpoint.infos["dev_loss"] = dev_loss

            # Update epoch in checkpoint
            self.checkpoint.infos["epoch"] = epoch + 1

            # Save checkpoint
            if (epoch + 1) % self.save_interval == 0 and self.rank == 0:
                logger.info(f"Saving checkpoint at epoch {epoch + 1}")
                self.checkpoint.save_parameters(tag=f"epoch_{epoch + 1}")

        # Save final model
        if self.rank == 0:
            logger.info("Saving final model")
            self.checkpoint.save_parameters(tag="final")
    def _train_epoch(self, train_loader: DataLoader, epoch: int, accum_grad: int = 1):
        """Train for one epoch."""
        self.model.train()

        train_loss = 0.0
        num_batches = 0
        steps_per_epoch = len(train_loader)

        for batch_idx, batch in enumerate(train_loader):
            if batch is None:
                continue

            # Get batch data
            mel = batch["mel"]
            tokens = batch["tokens"]
            labels = batch["labels"]

            # Forward pass
            audio_features = self.model.embed_audio(mel)
            logits = self.model.logits(tokens, audio_features)

            # Compute loss
            loss = F.cross_entropy(
                logits.reshape([-1, logits.shape[-1]]),
                labels.reshape([-1]),
                ignore_index=-100
            )

            # Scale loss for gradient accumulation
            if accum_grad > 1:
                loss = loss / accum_grad

            # Backward pass
            loss.backward()

            # Update parameters every accum_grad steps
            if (batch_idx + 1) % accum_grad == 0:
                self.optimizer.step()
                if getattr(self, 'lr_scheduler', None) is not None:
                    self.lr_scheduler.step()
                self.optimizer.clear_grad()

            # Logging (undo the accumulation scaling for reporting)
            train_loss += loss.item() * accum_grad
            num_batches += 1

            # Log training progress
            if batch_idx % self.log_interval == 0 and self.rank == 0:
                elapsed_time = self.timer.elapsed_interval()
                step = epoch * steps_per_epoch + batch_idx
                logger.info(f"Epoch {epoch+1} | Batch {batch_idx}/{steps_per_epoch} | "
                            f"Loss: {loss.item()*accum_grad:.4f} | "
                            f"Step {step} | {elapsed_time:.2f}s elapsed")

        # Compute average loss
        avg_loss = train_loss / num_batches if num_batches > 0 else float('inf')

        # Log epoch stats
        if self.rank == 0:
            logger.info(f"Epoch {epoch+1} | Average Training Loss: {avg_loss:.4f}")
            self.checkpoint.infos["train_loss"] = avg_loss

    def _eval_epoch(self, dev_loader: DataLoader, epoch: int):
        """Evaluate for one epoch."""
        self.model.eval()

        dev_losses = []

        with paddle.no_grad():
            for batch in dev_loader:
                if batch is None:
                    continue

                # Get batch data
                mel = batch["mel"]
                tokens = batch["tokens"]
                labels = batch["labels"]

                # Forward pass
                audio_features = self.model.embed_audio(mel)
                logits = self.model.logits(tokens, audio_features)

                # Compute loss
                loss = F.cross_entropy(
                    logits.reshape([-1, logits.shape[-1]]),
                    labels.reshape([-1]),
                    ignore_index=-100
                )

                dev_losses.append(loss.item())

        avg_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf')

        if self.rank == 0:
            logger.info(f"Epoch {epoch+1} | Validation Loss: {avg_loss:.4f}")

        return avg_loss

    def save(self, tag: str = "final"):
        """Save model checkpoint."""
        if self.rank == 0:
            logger.info(f"Saving model checkpoint with tag '{tag}'")
            self.checkpoint.save_parameters(tag=tag)

    def load(self, checkpoint_path: Union[str, Path]):
        """Load model from checkpoint."""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        self.checkpoint.load_parameters(checkpoint_path)

    @classmethod
    def from_pretrained(cls,
                        config: Dict,
                        model_size: str = "base",
                        checkpoint_path: Optional[str] = None,
                        resource_path: Optional[str] = None):
        """Create a trainer from pretrained model."""
        # Create model
        model_dims = MODEL_DIMENSIONS[model_size]
        model = Whisper(model_dims)

        # Load checkpoint if provided
        if checkpoint_path:
            state_dict = paddle.load(checkpoint_path)
            model.set_state_dict(state_dict)

        # Create trainer
        trainer = cls(
            config=config,
            model=model,
            resource_path=resource_path
        )

        return trainer
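

# A minimal usage sketch (paths and config contents are illustrative):
#
#   trainer = WhisperTrainer.from_pretrained(
#       config,
#       model_size="base",
#       checkpoint_path="./resources/whisper/whisper-base.pdparams",
#       resource_path="./resources")
#   trainer.train(train_loader, dev_loader, num_epochs=10)
#   trainer.save(tag="final")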
@ -0,0 +1,271 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import paddle
import yaml

from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.whisper import (
    MODEL_DIMENSIONS,
    LANGUAGES,
    Whisper,
    DecodingOptions,
    load_model,
    transcribe
)
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()
def load_whisper_model(
        model_size: str = "base",
        checkpoint_path: Optional[str] = None,
        resource_path: Optional[str] = None,
) -> Whisper:
    """Load Whisper model from checkpoint or pretrained weights.

    Args:
        model_size: Model size for Whisper
        checkpoint_path: Path to model checkpoint from fine-tuning
        resource_path: Path to resources directory containing original models

    Returns:
        Whisper model
    """
    model_dims = MODEL_DIMENSIONS[model_size]
    model = Whisper(model_dims)

    if checkpoint_path:
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")
        state_dict = paddle.load(checkpoint_path)
        model.set_state_dict(state_dict)
    elif resource_path:
        model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams")
        if os.path.exists(model_path):
            logger.info(f"Loading pretrained model from: {model_path}")
            state_dict = paddle.load(model_path)
            model.set_state_dict(state_dict)
        else:
            logger.error(f"Pretrained model not found at {model_path}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
    else:
        logger.error("Either checkpoint_path or resource_path must be provided")
        raise ValueError("Either checkpoint_path or resource_path must be provided")

    return model


def detect_language(
        model: Whisper,
        mel: paddle.Tensor,
        tokenizer=None,
        resource_path: Optional[str] = None) -> str:
    """Detect the spoken language from a mel spectrogram.

    Args:
        model: Whisper model
        mel: Mel spectrogram
        tokenizer: Optional tokenizer
        resource_path: Path to resources directory (required if tokenizer not provided)

    Returns:
        Detected language code
    """
    if not tokenizer and not resource_path:
        raise ValueError("Either tokenizer or resource_path must be provided")

    if not tokenizer:
        tokenizer = get_tokenizer(
            multilingual=model.is_multilingual,
            resource_path=resource_path)

    # Encode the audio into features
    audio_features = model.embed_audio(mel)

    # Start decoding from the start-of-transcript token
    initial_tokens = paddle.to_tensor([[tokenizer.sot]], dtype=paddle.int64)

    # Extract the logits over the language tokens (assumed to occupy a
    # contiguous range of token ids)
    token_logits = model.logits(initial_tokens, audio_features)
    language_token_logits = token_logits[:, 0, tokenizer.language_token_ranges]

    # Pick the language token with the highest probability
    language_token_probs = paddle.softmax(language_token_logits, axis=-1)
    language_token_id = paddle.argmax(language_token_probs, axis=-1)
    detected_language_token_id = (
        language_token_id.item() + tokenizer.language_token_ranges[0])

    # Map the special token (e.g. "<|en|>") back to its language code
    language_token = tokenizer.all_tokens[detected_language_token_id]
    language_code = language_token.strip("<|>")

    return language_code
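
# Example usage (sketch; assumes `mel` is a log-mel spectrogram tensor of
# shape [1, n_mels, n_frames] as produced by this example's data pipeline):
#
#     language_code = detect_language(model, mel, resource_path="./resources")
#     print(language_code)  # e.g. "en"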


def get_available_languages() -> List[str]:
    """Get the list of languages supported by Whisper.

    Returns:
        Sorted list of language codes
    """
    return sorted(LANGUAGES.keys())


def format_timestamp(
        seconds: float,
        always_include_hours: bool = False,
        decimal_marker: str = ".") -> str:
    """Format a timestamp (in seconds) as a string.

    Args:
        seconds: Number of seconds
        always_include_hours: Whether to always include the hours field
        decimal_marker: Decimal marker character ("." for plain text, "," for SRT)

    Returns:
        Formatted timestamp string, e.g. "01:02:03.450"
    """
    # Work in integer milliseconds to avoid floating-point drift and to emit
    # the three-digit millisecond field that the SRT format requires.
    milliseconds = round(max(0.0, seconds) * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"


def save_srt(
        transcript: Dict,
        file_path: str):
    """Save a transcript to an SRT subtitle file.

    Args:
        transcript: Transcript dictionary returned by Whisper
        file_path: Path to the output SRT file
    """
    if not transcript.get("segments"):
        return

    with open(file_path, "w", encoding="utf-8") as f:
        for i, segment in enumerate(transcript["segments"], start=1):
            start = format_timestamp(
                segment["start"], always_include_hours=True, decimal_marker=",")
            end = format_timestamp(
                segment["end"], always_include_hours=True, decimal_marker=",")
            # A literal "-->" inside subtitle text would break SRT parsing
            text = segment["text"].strip().replace("-->", "->")

            f.write(f"{i}\n")
            f.write(f"{start} --> {end}\n")
            f.write(f"{text}\n\n")


def batch_transcribe_with_progress(
        model: Whisper,
        dataset,
        dataloader,
        resource_path: str,
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        temperature: float = 0.0,
        without_timestamps: bool = True,
        verbose: bool = False,
        decoder_options: Optional[dict] = None):
    """Transcribe audio files in batches with progress reporting.

    Args:
        model: Whisper model
        dataset: WhisperInferenceDataset instance
        dataloader: DataLoader built from the dataset
        resource_path: Path to the resources directory
        language: Language code, or None to auto-detect per utterance
        task: Task to perform ("transcribe" or "translate")
        beam_size: Beam size for beam search
        temperature: Sampling temperature
        without_timestamps: Whether to omit timestamps from the output
        verbose: Whether to log progress and intermediate results
        decoder_options: Additional decoder options

    Returns:
        List of transcription results
    """
    model.eval()
    tokenizer = get_tokenizer(
        multilingual=model.is_multilingual,
        resource_path=resource_path,
        language=language,
        task=task)

    results = []
    total_batches = len(dataloader)

    with paddle.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if verbose:
                logger.info(f"Processing batch {batch_idx + 1}/{total_batches}")

            mel = batch["mel"]  # [batch, n_mels, n_frames]
            audio_paths = batch["audio_paths"]
            batch_size = mel.shape[0]

            # Decode each utterance in the batch individually
            for i in range(batch_size):
                audio_path = audio_paths[i]
                mel_i = mel[i:i + 1]  # keep the batch dimension

                # Auto-detect the language if none was specified
                current_language = language
                if not current_language:
                    current_language = detect_language(
                        model, mel_i, tokenizer, resource_path)
                    if verbose:
                        logger.info(f"Detected language: {current_language}")

                # Set up decoding options
                options = DecodingOptions(
                    task=task,
                    language=current_language,
                    beam_size=beam_size,
                    temperature=temperature,
                    without_timestamps=without_timestamps,
                    **(decoder_options or {}))

                # Transcribe, expanding the decoding options as keyword args
                result = transcribe(
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_i,
                    resource_path=resource_path,
                    verbose=verbose,
                    **options.__dict__)

                # Attach the source file path to the result
                result["audio_path"] = audio_path
                results.append(result)

                if verbose:
                    logger.info(
                        f"Transcribed {os.path.basename(audio_path)}: "
                        f"{result['text'][:80]}...")

    return results
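
# Example usage (sketch; the WhisperInferenceDataset constructor arguments and
# batch size below are assumptions for illustration, not part of this module):
#
#     dataset = WhisperInferenceDataset(manifest_path="data/manifest.test")
#     dataloader = paddle.io.DataLoader(dataset, batch_size=4)
#     results = batch_transcribe_with_progress(
#         model, dataset, dataloader,
#         resource_path="./resources",
#         language=None,  # auto-detect per utterance
#         verbose=True)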