diff --git a/examples/commonvoice/whisper/README.md b/examples/commonvoice/whisper/README.md new file mode 100644 index 000000000..98af7051c --- /dev/null +++ b/examples/commonvoice/whisper/README.md @@ -0,0 +1,238 @@ +# Whisper Fine-tuning on Common Voice + +This example demonstrates how to fine-tune Whisper models on the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) dataset using PaddleSpeech. + +## Overview + +Whisper is a state-of-the-art speech recognition model from OpenAI. This implementation allows you to fine-tune Whisper models on new datasets to improve performance for specific languages, domains, or accents. + +## Features + +- Complete fine-tuning pipeline for Whisper models on custom datasets +- Flexible configuration via YAML files +- Support for all Whisper model sizes (tiny, base, small, medium, large, etc.) +- Data preparation tools for Common Voice and custom datasets +- Distributed training support with mixed precision +- Gradient accumulation for large batch training +- Learning rate scheduling and optimization techniques +- Evaluation tools with WER/CER metrics +- Command-line inference with both fine-tuned and original Whisper models +- Model export utilities for deployment +- Visualization tools for performance analysis + +## Installation + +Ensure you have PaddleSpeech installed with all dependencies: + +```bash +git clone https://github.com/PaddlePaddle/PaddleSpeech.git +cd PaddleSpeech +pip install -e . +pip install datasets soundfile librosa matplotlib pandas jiwer +``` + +## Data + +We use the Common Voice dataset (version 11.0) available on Hugging Face: +https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0 + +Other datasets compatible with the pipeline include LibriSpeech, AISHELL, and any dataset that can be converted to the manifest format (see below). + +## Unified Command-Line Interface + +This example includes a unified CLI for all operations: + +```bash +python whisper_cli.py COMMAND [OPTIONS] +``` + +Available commands: +- `prepare`: Prepare dataset for fine-tuning +- `train`: Fine-tune Whisper model +- `evaluate`: Evaluate model performance +- `infer`: Run inference with fine-tuned or original model + +## Data Preparation + +To download and prepare the Common Voice dataset: + +```bash +python whisper_cli.py prepare --language en --output_dir ./data +``` + +Options: +- `--language`: Target language code (default: en) +- `--output_dir`: Directory to save preprocessed data (default: ./data) +- `--cache_dir`: Cache directory for HuggingFace datasets +- `--val_size`: Validation set size ratio (default: 0.03) +- `--test_size`: Test set size ratio (default: 0.03) +- `--min_duration`: Minimum audio duration in seconds (default: 0.5) +- `--max_duration`: Maximum audio duration in seconds (default: 30.0) + +Manifest format: +```json +{"audio": "path/to/audio.wav", "text": "transcription", "duration": 3.45} +``` + +## Configuration + +Fine-tuning parameters are specified in YAML config files. See `conf/whisper_base.yaml` for a detailed example. 
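+
+An abridged excerpt of `conf/whisper_base.yaml` showing how the main sections fit together (values shown are the defaults shipped with this example):
+
+```yaml
+model:
+  size: "base"              # tiny, base, small, medium, large, large-v2, large-v3
+  freeze_encoder: false
+
+data:
+  train_manifest: "data/train_manifest.json"
+  dev_manifest: "data/dev_manifest.json"
+  target_language: "en"
+
+training:
+  max_epoch: 20
+  batch_size: 16
+  learning_rate: 1e-5
+  accum_grad: 1
+
+output:
+  checkpoint_dir: "exp/whisper_fine_tune"
+```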
+ +Key configuration sections: +- **model**: Model size, checkpoint path, freeze options +- **data**: Dataset paths, languages, tasks +- **training**: Batch size, learning rate, optimizer settings +- **distributed**: Distributed training options +- **output**: Save paths, logging options + +## Training + +To fine-tune the Whisper model: + +```bash +python whisper_cli.py train --config conf/whisper_base.yaml --resource_path ./resources +``` + +For distributed training: + +```bash +python -m paddle.distributed.launch --gpus "0,1,2,3" whisper_cli.py train --config conf/whisper_base.yaml --distributed True +``` + +Options: +- `--config`: Path to configuration YAML file +- `--resource_path`: Path to resources directory containing model assets +- `--device`: Device to use (cpu, gpu, xpu) +- `--seed`: Random seed +- `--checkpoint_path`: Path to resume training from checkpoint +- `--distributed`: Enable distributed training + +## Evaluation + +Evaluate model performance on a test set: + +```bash +python whisper_cli.py evaluate --manifest ./data/test_manifest.json --checkpoint ./exp/whisper_fine_tune/epoch_10 --output_dir ./eval_results +``` + +Options: +- `--manifest`: Path to test manifest file +- `--checkpoint`: Path to model checkpoint +- `--model_size`: Model size if using original Whisper +- `--language`: Language code +- `--output_dir`: Directory to save evaluation results +- `--max_samples`: Maximum number of samples to evaluate + +## Inference + +For transcribing audio with a fine-tuned model: + +```bash +python whisper_cli.py infer --audio_file path/to/audio.wav --checkpoint ./exp/whisper_fine_tune/final +``` + +For batch processing a directory: + +```bash +python whisper_cli.py infer --audio_dir path/to/audio/folder --output_dir ./transcriptions --checkpoint ./exp/whisper_fine_tune/final +``` + +For inference with the original Whisper models: + +```bash +python whisper_cli.py infer --audio_file path/to/audio.wav --use_original --model_size large-v3 --resource_path ./resources +``` + +Options: +- `--audio_file`: Path to single audio file +- `--audio_dir`: Path to directory with audio files +- `--checkpoint`: Path to fine-tuned checkpoint +- `--use_original`: Use original Whisper model +- `--model_size`: Model size (tiny, base, small, medium, large, etc.) 
+- `--language`: Language code (or "auto" for detection) +- `--task`: Task type (transcribe or translate) +- `--beam_size`: Beam size for beam search +- `--temperature`: Temperature for sampling +- `--without_timestamps`: Don't include timestamps + +## Visualization + +Visualize evaluation results: + +```bash +python visualize.py --results_file ./eval_results/evaluation_results.json --output_dir ./visualizations +``` + +Options: +- `--results_file`: Path to evaluation results JSON file +- `--output_dir`: Directory to save visualizations +- `--audio_dir`: Directory with audio files (optional) +- `--num_samples`: Number of individual samples to visualize +- `--show`: Show plots interactively + +## Model Export + +Export fine-tuned model to inference format: + +```bash +python export_model.py --checkpoint ./exp/whisper_fine_tune/final --output_path ./exported_model --model_size base +``` + +Options: +- `--checkpoint`: Path to model checkpoint +- `--output_path`: Path to save exported model +- `--model_size`: Model size + +## Advanced Usage + +### Freezing Encoder + +To freeze the encoder and only fine-tune the decoder, set the following in your config file: + +```yaml +model: + freeze_encoder: true +``` + +### Gradient Accumulation + +For effective training with limited GPU memory, use gradient accumulation: + +```yaml +training: + accum_grad: 8 # Accumulate gradients over 8 batches +``` + +### Mixed Precision + +Enable mixed precision training for faster computation: + +```yaml +training: + amp: true # Enable automatic mixed precision +``` + +### Custom Datasets + +To use custom datasets, prepare manifest files in the following format: + +```json +{"audio": "/absolute/path/to/audio.wav", "text": "transcription text"} +``` + +Then specify the manifest paths in your config file: + +```yaml +data: + train_manifest: path/to/train_manifest.json + dev_manifest: path/to/dev_manifest.json + test_manifest: path/to/test_manifest.json +``` + +## Reference + +- [OpenAI Whisper](https://github.com/openai/whisper) +- [Fine-tuning Whisper](https://huggingface.co/blog/fine-tune-whisper) +- [Paper: Robust Speech Recognition via Large-Scale Weak Supervision](https://arxiv.org/pdf/2212.04356) +- [Common Voice Dataset](https://commonvoice.mozilla.org/) +- [PaddleSpeech Documentation](https://github.com/PaddlePaddle/PaddleSpeech) diff --git a/examples/commonvoice/whisper/conf/whisper_base.yaml b/examples/commonvoice/whisper/conf/whisper_base.yaml new file mode 100644 index 000000000..f81873b15 --- /dev/null +++ b/examples/commonvoice/whisper/conf/whisper_base.yaml @@ -0,0 +1,70 @@ +# Configuration for fine-tuning Whisper base model + +# Data settings +data: + train_manifest: "data/train_manifest.json" + dev_manifest: "data/dev_manifest.json" + test_manifest: "data/test_manifest.json" + target_language: "en" # Language code for fine-tuning + max_duration: 30.0 # Maximum audio duration in seconds + min_duration: 0.5 # Minimum audio duration in seconds + +# Model settings +model: + name: "whisper" + size: "base" # Options: tiny, base, small, medium, large, large-v2, large-v3 + checkpoint: null # Path to pre-trained checkpoint, null for default + freeze_encoder: false # Whether to freeze the encoder during fine-tuning + use_fp16: true # Whether to use half precision + +# Training settings +training: + max_epoch: 20 + save_epoch: 1 + log_interval: 100 + batch_size: 16 + num_workers: 4 + accum_grad: 1 # Gradient accumulation steps + + # Optimizer settings + optimizer: "adamw" + learning_rate: 1e-5 + weight_decay: 0.01 
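+  # Schedule: linear warmup for the first `warmup_ratio` fraction of training
+  # steps, then cosine decay (see _build_optimizer in train.py)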
+ scheduler: "cosine" + warmup_ratio: 0.03 + max_grad_norm: 1.0 + + # Regularization + dropout: 0.1 + label_smoothing: 0.1 + + # Mixed precision training + amp_level: "O1" + amp_dtype: "float16" + +# Distributed training +distributed: + use_fleet: true + strategy: "standard" + find_unused_parameters: false + +# Output settings +output: + checkpoint_dir: "exp/whisper_fine_tune" + save_checkpoint: true + save_interval: 1 + keep_checkpoint_max: 5 + +# Evaluation settings +eval: + eval_batch_size: 16 + metrics: ["wer", "cer"] + +# Inference settings +inference: + beam_size: 5 + min_tokens: 0 + max_tokens: 448 + temperature: 0.0 + language: null # Set to target language code or null to auto-detect + without_timestamps: true diff --git a/examples/commonvoice/whisper/evaluate.py b/examples/commonvoice/whisper/evaluate.py new file mode 100644 index 000000000..16cb2eac0 --- /dev/null +++ b/examples/commonvoice/whisper/evaluate.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import time +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import paddle +import yaml + +from paddlespeech.s2t.models.whisper.whisper import (MODEL_DIMENSIONS, Whisper, + DecodingOptions, log_mel_spectrogram, + transcribe) +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import get_rank, str2bool +from paddlespeech.metrics.wer import word_errors, char_errors + +logger = Log(__name__).getlog() + + +def compute_metrics(references, hypotheses, language="en"): + """Compute WER and CER metrics.""" + total_words = 0 + total_chars = 0 + total_word_errors = 0 + total_char_errors = 0 + + for ref, hyp in zip(references, hypotheses): + ref = ref.strip() + hyp = hyp.strip() + + word_error_count, word_count = word_errors(ref, hyp, language) + char_error_count, char_count = char_errors(ref, hyp, language) + + total_words += word_count + total_chars += char_count + total_word_errors += word_error_count + total_char_errors += char_error_count + + wer = float(total_word_errors) / max(1, total_words) + cer = float(total_char_errors) / max(1, total_chars) + + return { + "wer": wer, + "cer": cer, + "word_errors": total_word_errors, + "word_count": total_words, + "char_errors": total_char_errors, + "char_count": total_chars, + } + + +def load_model(model_size="base", checkpoint_path=None, resource_path=None): + """Load Whisper model from checkpoint or pretrained weights.""" + model_dims = MODEL_DIMENSIONS[model_size] + model = Whisper(model_dims) + + if checkpoint_path: + logger.info(f"Loading model from checkpoint: {checkpoint_path}") + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + elif resource_path: + model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams") + if os.path.exists(model_path): + logger.info(f"Loading pretrained model from: {model_path}") + state_dict = 
paddle.load(model_path) + model.set_state_dict(state_dict) + else: + logger.error(f"Pretrained model not found at {model_path}") + raise FileNotFoundError(f"Model file not found: {model_path}") + else: + logger.error("Either checkpoint_path or resource_path must be provided") + raise ValueError("Either checkpoint_path or resource_path must be provided") + + return model + + +def evaluate_manifest( + model, + manifest_path, + resource_path, + language=None, + task="transcribe", + temperature=0.0, + beam_size=5, + patience=1.0, + batch_size=16, + without_timestamps=True, + fp16=False, + verbose=True, + output_dir=None, + max_samples=None +): + """Evaluate Whisper model on a manifest file.""" + # Load manifest + with open(manifest_path, 'r', encoding='utf8') as f: + manifest_lines = f.readlines() + + # Limit samples if requested + if max_samples: + manifest_lines = manifest_lines[:max_samples] + + references = [] + hypotheses = [] + audio_paths = [] + durations = [] + + # Process each item + start_time = time.time() + model.eval() + + for i, line in enumerate(manifest_lines): + if i % 10 == 0: + logger.info(f"Processing item {i+1}/{len(manifest_lines)}") + + item = json.loads(line.strip()) + audio_path = item["audio"] + reference_text = item["text"] + + # Get duration if available + duration = item.get("duration", 0) + durations.append(duration) + + audio_paths.append(audio_path) + references.append(reference_text) + + # Process audio + try: + mel = log_mel_spectrogram(audio_path, resource_path=resource_path) + + # Setup decoding options + decode_options = DecodingOptions( + language=language, + task=task, + temperature=temperature, + beam_size=beam_size, + patience=patience, + without_timestamps=without_timestamps, + fp16=fp16, + ) + + # Run transcription + with paddle.no_grad(): + result = transcribe( + model=model, + mel=mel, + resource_path=resource_path, + verbose=False, + **decode_options.__dict__, + ) + + hypotheses.append(result["text"]) + + except Exception as e: + logger.error(f"Error processing {audio_path}: {str(e)}") + hypotheses.append("") + + # Compute metrics + elapsed = time.time() - start_time + logger.info(f"Processed {len(manifest_lines)} examples in {elapsed:.2f} seconds") + + metrics = compute_metrics(references, hypotheses, language=language if language else "en") + logger.info(f"WER: {metrics['wer']*100:.2f}%, CER: {metrics['cer']*100:.2f}%") + + # Save results if output directory provided + if output_dir: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Save detailed results + results = { + "metrics": metrics, + "details": [ + { + "audio": audio_path, + "reference": reference, + "hypothesis": hypothesis, + "duration": duration + } + for audio_path, reference, hypothesis, duration in zip( + audio_paths, references, hypotheses, durations + ) + ] + } + + with open(output_dir / "evaluation_results.json", "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + logger.info(f"Detailed evaluation results saved to {output_dir}") + + return metrics + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate Whisper ASR Model") + parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file") + parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results") + parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning") + parser.add_argument("--resource_path", type=str, default="./resources", + 
help="Path to resources directory") + + # Model options + parser.add_argument("--model_size", type=str, default="base", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], + help="Model size for original Whisper") + + # Decoding options + parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)") + parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], + help="Task: transcribe or translate to English") + parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling") + parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search") + parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor") + parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps") + parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation") + parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16") + parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display verbose logs") + parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate") + + args = parser.parse_args() + + # Load model + model = load_model( + model_size=args.model_size, + checkpoint_path=args.checkpoint, + resource_path=args.resource_path + ) + + # Evaluate + evaluate_manifest( + model=model, + manifest_path=args.manifest, + resource_path=args.resource_path, + language=args.language, + task=args.task, + temperature=args.temperature, + beam_size=args.beam_size, + patience=args.patience, + batch_size=args.batch_size, + without_timestamps=args.without_timestamps, + fp16=args.fp16, + verbose=args.verbose, + output_dir=args.output_dir, + max_samples=args.max_samples + ) + + +if __name__ == "__main__": + main() diff --git a/examples/commonvoice/whisper/export_model.py b/examples/commonvoice/whisper/export_model.py new file mode 100644 index 000000000..a85d97c17 --- /dev/null +++ b/examples/commonvoice/whisper/export_model.py @@ -0,0 +1,135 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
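+
+"""Export a fine-tuned Whisper checkpoint to static-graph inference models.
+
+The encoder and decoder are exported separately with paddle.jit.save, and the
+model dimensions are written alongside them as `<output_path>_info.json`.
+"""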
+ +import argparse +import os +from pathlib import Path + +import paddle + +from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +def export_encoder(model, save_path, input_shape=(1, 80, 3000)): + """Export encoder part of Whisper to inference model.""" + model.eval() + + # Create save directory if not exists + save_dir = os.path.dirname(save_path) + os.makedirs(save_dir, exist_ok=True) + + # Define input spec + mel_spec = paddle.static.InputSpec(shape=input_shape, dtype='float32', name='mel') + + # Export encoder model + encoder_path = f"{save_path}_encoder" + paddle.jit.save( + layer=model.encoder, + path=encoder_path, + input_spec=[mel_spec] + ) + logger.info(f"Encoder model exported to {encoder_path}") + return encoder_path + + +def export_decoder(model, save_path, input_shape_tokens=(1, 448), input_shape_features=(1, 1500, 768)): + """Export decoder part of Whisper to inference model.""" + model.eval() + + # Create save directory if not exists + save_dir = os.path.dirname(save_path) + os.makedirs(save_dir, exist_ok=True) + + # Define input spec + token_spec = paddle.static.InputSpec(shape=input_shape_tokens, dtype='int64', name='tokens') + audio_features_spec = paddle.static.InputSpec(shape=input_shape_features, dtype='float32', name='audio_features') + + # Create a wrapper to match the exact API of the decoder + class DecoderWrapper(paddle.nn.Layer): + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + + def forward(self, tokens, audio_features): + return self.decoder(tokens, audio_features) + + wrapper = DecoderWrapper(model.decoder) + + # Export decoder model + decoder_path = f"{save_path}_decoder" + paddle.jit.save( + layer=wrapper, + path=decoder_path, + input_spec=[token_spec, audio_features_spec] + ) + logger.info(f"Decoder model exported to {decoder_path}") + return decoder_path + + +def export_whisper(model, save_path): + """Export full Whisper model to static graph models.""" + export_encoder(model, save_path) + export_decoder(model, save_path) + + # Export model info + dims = model.dims + model_info = { + "n_mels": dims.n_mels, + "n_vocab": dims.n_vocab, + "n_audio_ctx": dims.n_audio_ctx, + "n_audio_state": dims.n_audio_state, + "n_audio_head": dims.n_audio_head, + "n_audio_layer": dims.n_audio_layer, + "n_text_ctx": dims.n_text_ctx, + "n_text_state": dims.n_text_state, + "n_text_head": dims.n_text_head, + "n_text_layer": dims.n_text_layer + } + + # Save model info + import json + with open(f"{save_path}_info.json", "w") as f: + json.dump(model_info, f, indent=4) + + logger.info(f"Model info saved to {save_path}_info.json") + + +def main(): + parser = argparse.ArgumentParser(description="Export Whisper model to inference format") + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path to save exported model") + parser.add_argument("--model_size", type=str, default="base", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], + help="Model size") + + args = parser.parse_args() + + # Create model + model_dims = MODEL_DIMENSIONS[args.model_size] + model = Whisper(model_dims) + + # Load checkpoint + state_dict = paddle.load(args.checkpoint) + model.set_state_dict(state_dict) + + # Export model + export_whisper(model, args.output_path) + logger.info(f"Model exported to {args.output_path}") + + +if __name__ == 
"__main__": + main() diff --git a/examples/commonvoice/whisper/infer.py b/examples/commonvoice/whisper/infer.py new file mode 100644 index 000000000..833a26d1b --- /dev/null +++ b/examples/commonvoice/whisper/infer.py @@ -0,0 +1,323 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import paddle +import yaml + +from paddlespeech.s2t.models.whisper.whisper import (CHUNK_LENGTH, MODEL_DIMENSIONS, N_MELS, + DecodingOptions, Whisper, + log_mel_spectrogram, pad_or_trim, + transcribe) +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import get_files_by_ext, str2bool + +logger = Log(__name__).getlog() + + +def load_audio(file_path, sample_rate=16000): + """Load audio file and convert to 16kHz mono if needed.""" + try: + import librosa + audio, sr = librosa.load(file_path, sr=sample_rate, mono=True) + return audio + except Exception as e: + logger.error(f"Error loading audio file {file_path}: {str(e)}") + return None + + +def get_model_path(size, resource_path): + """Get the model path based on size.""" + return os.path.join(resource_path, "whisper", f"whisper-{size}.pdparams") + + +def load_whisper_model(model_size="base", checkpoint_path=None, resource_path=None): + """Load Whisper model from checkpoint or pretrained weights.""" + model_dims = MODEL_DIMENSIONS[model_size] + model = Whisper(model_dims) + + if checkpoint_path: + logger.info(f"Loading model from checkpoint: {checkpoint_path}") + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + elif resource_path: + model_path = get_model_path(model_size, resource_path) + if os.path.exists(model_path): + logger.info(f"Loading pretrained model from: {model_path}") + state_dict = paddle.load(model_path) + model.set_state_dict(state_dict) + else: + logger.error(f"Pretrained model not found at {model_path}") + raise FileNotFoundError(f"Model file not found: {model_path}") + else: + logger.error("Either checkpoint_path or resource_path must be provided") + raise ValueError("Either checkpoint_path or resource_path must be provided") + + return model + + +def transcribe_file( + model, + audio_file, + resource_path, + language=None, + task="transcribe", + temperature=0.0, + beam_size=5, + best_of=5, + patience=1.0, + length_penalty=1.0, + suppress_tokens="-1", + initial_prompt=None, + condition_on_previous_text=True, + without_timestamps=False, + fp16=False, + verbose=True, +): + """Transcribe a single audio file.""" + # Check if file exists + if not os.path.exists(audio_file): + logger.error(f"Audio file not found: {audio_file}") + return None + + # Load and process audio + logger.info(f"Processing audio file: {audio_file}") + try: + mel = log_mel_spectrogram(audio_file, resource_path=resource_path) + except Exception as e: + logger.error(f"Error processing audio: {str(e)}") + return None + + # 
Setup decoding options + decode_options = DecodingOptions( + language=language, + task=task, + temperature=temperature, + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + suppress_tokens=suppress_tokens, + prompt=initial_prompt, + without_timestamps=without_timestamps, + fp16=fp16, + ) + + # Run transcription + logger.info("Running transcription...") + start_time = time.time() + + model.eval() + with paddle.no_grad(): + result = transcribe( + model=model, + mel=mel, + resource_path=resource_path, + verbose=verbose, + condition_on_previous_text=condition_on_previous_text, + **decode_options.__dict__, + ) + + elapsed = time.time() - start_time + logger.info(f"Transcription completed in {elapsed:.2f} seconds") + + return result + + +def batch_transcribe( + model, + audio_dir, + output_dir, + resource_path, + extensions=["wav", "mp3", "flac", "m4a", "ogg"], + **transcribe_kwargs +): + """Transcribe all audio files in a directory.""" + # Create output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get list of audio files + audio_files = [] + for ext in extensions: + audio_files.extend(get_files_by_ext(audio_dir, ext)) + + if not audio_files: + logger.error(f"No audio files found in {audio_dir} with extensions {extensions}") + return + + logger.info(f"Found {len(audio_files)} audio files to process") + + # Process each file + results = {} + for audio_file in audio_files: + base_name = os.path.basename(audio_file) + file_name = os.path.splitext(base_name)[0] + + logger.info(f"Processing file: {base_name}") + result = transcribe_file(model, audio_file, resource_path, **transcribe_kwargs) + + if result: + # Save transcription + output_file = output_dir / f"{file_name}.txt" + with open(output_file, "w", encoding="utf-8") as f: + f.write(result["text"].strip()) + + # Save detailed results + results[base_name] = { + "text": result["text"].strip(), + "segments": result.get("segments", []), + "language": result.get("language", ""), + } + + # Save all results as JSON + import json + with open(output_dir / "all_transcripts.json", "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + logger.info(f"All transcriptions saved to {output_dir}") + return results + + +def main(): + parser = argparse.ArgumentParser(description="Whisper ASR Inference") + parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription") + parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files") + parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions") + parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning") + parser.add_argument("--resource_path", type=str, default="./resources", + help="Path to resources directory containing original models and assets") + + # Model options + parser.add_argument("--use_original", type=str2bool, default=False, + help="Use original Whisper model instead of fine-tuned") + parser.add_argument("--model_size", type=str, default="base", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], + help="Model size for original Whisper") + + # Decoding options + parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)") + parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], + help="Task: transcribe or 
translate to English") + parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling") + parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search") + parser.add_argument("--best_of", type=int, default=5, help="Number of candidates when sampling with temp > 0") + parser.add_argument("--patience", type=float, default=1.0, help="Beam search patience factor") + parser.add_argument("--length_penalty", type=float, default=1.0, help="Exponential length penalty") + parser.add_argument("--suppress_tokens", type=str, default="-1", help="Comma-separated list of token ids to suppress") + parser.add_argument("--initial_prompt", type=str, default=None, help="Optional text to provide as prompt") + parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, + help="Whether to condition on previous text") + parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps") + parser.add_argument("--fp16", type=str2bool, default=False, help="Use half-precision float16") + parser.add_argument("--verbose", type=str2bool, default=True, help="Whether to display the text being decoded") + + args = parser.parse_args() + + # Validate arguments + if not args.audio_file and not args.audio_dir: + parser.error("Either --audio_file or --audio_dir must be specified") + + if args.use_original: + # Use original Whisper model + if not args.resource_path: + parser.error("--resource_path must be specified when using original model") + + model = load_whisper_model( + model_size=args.model_size, + resource_path=args.resource_path + ) + else: + # Use fine-tuned model + if not args.checkpoint: + parser.error("--checkpoint must be specified when not using original model") + + # Determine model size from checkpoint directory structure + model_size = "base" # Default + if "tiny" in args.checkpoint: + model_size = "tiny" + elif "small" in args.checkpoint: + model_size = "small" + elif "medium" in args.checkpoint: + model_size = "medium" + elif "large" in args.checkpoint: + model_size = "large" + + model = load_whisper_model( + model_size=model_size, + checkpoint_path=args.checkpoint, + resource_path=args.resource_path + ) + + # Prepare transcription keyword arguments + transcribe_kwargs = { + "language": args.language, + "task": args.task, + "temperature": args.temperature, + "beam_size": args.beam_size, + "best_of": args.best_of, + "patience": args.patience, + "length_penalty": args.length_penalty, + "suppress_tokens": args.suppress_tokens, + "initial_prompt": args.initial_prompt, + "condition_on_previous_text": args.condition_on_previous_text, + "without_timestamps": args.without_timestamps, + "fp16": args.fp16, + "verbose": args.verbose, + } + + # Run transcription + if args.audio_file: + # Single file transcription + result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs) + if result: + print("-" * 40) + print("Transcription:") + print(result["text"]) + print("-" * 40) + + # Save to output directory if specified + if args.output_dir: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + file_name = os.path.splitext(os.path.basename(args.audio_file))[0] + with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f: + f.write(result["text"].strip()) + + import json + with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + print(f"Results saved to 
{output_dir}") + + elif args.audio_dir: + # Batch transcription + batch_transcribe( + model, + args.audio_dir, + args.output_dir, + args.resource_path, + **transcribe_kwargs + ) + + +if __name__ == "__main__": + main() diff --git a/examples/commonvoice/whisper/prepare_data.py b/examples/commonvoice/whisper/prepare_data.py new file mode 100644 index 000000000..b89294092 --- /dev/null +++ b/examples/commonvoice/whisper/prepare_data.py @@ -0,0 +1,197 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import random +import shutil +from pathlib import Path +from typing import Dict, List, Optional, Union + +import datasets +import numpy as np +import soundfile +import tqdm +import yaml + + +def prepare_common_voice(language: str, + output_dir: str, + cache_dir: Optional[str] = None, + val_size: float = 0.03, + test_size: float = 0.03, + min_duration: float = 0.5, + max_duration: float = 30.0, + seed: int = 42): + """ + Prepare Mozilla Common Voice dataset for Whisper fine-tuning. + + Args: + language: Language code (e.g., "en", "fr", "es") + output_dir: Directory to save preprocessed data + cache_dir: Cache directory for HuggingFace datasets + val_size: Validation set size as a fraction of total data + test_size: Test set size as a fraction of total data + min_duration: Minimum audio duration in seconds + max_duration: Maximum audio duration in seconds + seed: Random seed for reproducibility + """ + print(f"Preparing Common Voice dataset for language: {language}") + + # Create output directories + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "wavs").mkdir(exist_ok=True) + + # Load Common Voice dataset + print("Loading Common Voice dataset from HuggingFace...") + try: + common_voice = datasets.load_dataset( + "mozilla-foundation/common_voice_11_0", + language, + cache_dir=cache_dir, + trust_remote_code=True + ) + except Exception as e: + print(f"Error loading dataset: {e}") + print("Make sure you have access to the Common Voice dataset on HuggingFace.") + return + + # Filter and process data + print("Processing dataset...") + + # Function to process and filter examples + def process_example(example): + audio = example['audio'] + sample_rate = audio['sampling_rate'] + + # Calculate duration + duration = len(audio['array']) / sample_rate + + # Filter by duration + if duration < min_duration or duration > max_duration: + return None + + return { + "path": None, # Will be filled later + "audio": audio, + "text": example['sentence'], + "duration": duration, + } + + # Process all splits + all_data = [] + for split in ['train', 'validation', 'test']: + if split in common_voice: + split_data = [] + for example in tqdm.tqdm(common_voice[split], desc=f"Processing {split} set"): + processed = process_example(example) + if processed: + split_data.append(processed) + all_data.extend(split_data) + + # Shuffle and split data + random.seed(seed) + 
random.shuffle(all_data) + + total_size = len(all_data) + val_count = max(1, int(total_size * val_size)) + test_count = max(1, int(total_size * test_size)) + train_count = total_size - val_count - test_count + + train_data = all_data[:train_count] + val_data = all_data[train_count:train_count + val_count] + test_data = all_data[train_count + val_count:] + + print(f"Dataset split - Train: {len(train_data)}, Dev: {len(val_data)}, Test: {len(test_data)}") + + # Save audio files and create manifest files + def save_manifest(data, name): + manifest = [] + for i, item in enumerate(tqdm.tqdm(data, desc=f"Saving {name} files")): + # Generate filename + filename = f"{name}_{i:08d}.wav" + filepath = str(output_dir / "wavs" / filename) + + # Save audio file + soundfile.write( + filepath, + item["audio"]["array"], + item["audio"]["sampling_rate"] + ) + + # Add to manifest + manifest_item = { + "utt": f"{name}_{i:08d}", + "audio": filepath, + "text": item["text"], + "duration": item["duration"] + } + manifest.append(manifest_item) + + # Write manifest file + with open(output_dir / f"{name}_manifest.json", "w", encoding="utf-8") as f: + for item in manifest: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + save_manifest(train_data, "train") + save_manifest(val_data, "dev") + save_manifest(test_data, "test") + + print(f"Dataset preparation complete. Files saved to {output_dir}") + + # Save config + stats = { + "language": language, + "total_examples": total_size, + "train_examples": len(train_data), + "dev_examples": len(val_data), + "test_examples": len(test_data), + "min_duration": min_duration, + "max_duration": max_duration, + } + + with open(output_dir / "stats.yaml", "w") as f: + yaml.dump(stats, f) + + print("Data preparation complete!") + + +def main(): + parser = argparse.ArgumentParser(description="Prepare Common Voice dataset for Whisper fine-tuning") + parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)") + parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data") + parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets") + parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction") + parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction") + parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds") + parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + args = parser.parse_args() + prepare_common_voice( + language=args.language, + output_dir=args.output_dir, + cache_dir=args.cache_dir, + val_size=args.val_size, + test_size=args.test_size, + min_duration=args.min_duration, + max_duration=args.max_duration, + seed=args.seed + ) + + +if __name__ == "__main__": + main() diff --git a/examples/commonvoice/whisper/train.py b/examples/commonvoice/whisper/train.py new file mode 100644 index 000000000..5e0e492ee --- /dev/null +++ b/examples/commonvoice/whisper/train.py @@ -0,0 +1,510 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F +import yaml +from paddle.io import DataLoader, Dataset +from paddle.optimizer import AdamW +from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup + +from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer +from paddlespeech.s2t.models.whisper.whisper import (CHUNK_LENGTH, MODEL_DIMENSIONS, N_MELS, N_SAMPLES, Whisper, + log_mel_spectrogram, pad_or_trim) +from paddlespeech.s2t.utils.checkpoint import Checkpoint +from paddlespeech.s2t.training.extensions.evaluator import StandardEvaluator +from paddlespeech.s2t.training.extensions.reporter import ObsScope, Reporter +from paddlespeech.s2t.training.scheduler import LRSchedulerFactory +from paddlespeech.s2t.training.trainer import Trainer +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import get_rank, str2bool + +logger = Log(__name__).getlog() + + +class WhisperDataset(Dataset): + """Dataset for Whisper fine-tuning""" + + def __init__( + self, + manifest_path: str, + tokenizer, + target_language: str = "en", + max_duration: float = 30.0, + min_duration: float = 0.5, + sample_rate: int = 16000, + resource_path: str = '', + ): + """Initialize the dataset. 
+ + Args: + manifest_path: Path to manifest file with audio paths and transcripts + tokenizer: Whisper tokenizer + target_language: Target language code + max_duration: Maximum audio duration + min_duration: Minimum audio duration + sample_rate: Audio sample rate + resource_path: Path to resources directory + """ + super().__init__() + + self.tokenizer = tokenizer + self.target_language = target_language + self.sample_rate = sample_rate + self.resource_path = resource_path + + # Load manifest + with open(manifest_path, 'r', encoding='utf8') as f: + manifest_lines = f.readlines() + + self.data = [] + for line in manifest_lines: + try: + item = json.loads(line.strip()) + duration = item.get('duration', 0) + if min_duration <= duration <= max_duration: + self.data.append(item) + except Exception as e: + logger.warning(f"Error parsing line in manifest: {e}") + + logger.info(f"Loaded {len(self.data)} examples from {manifest_path}") + + # Generate special tokens and language tokens + self.special_tokens = { + "sot": self.tokenizer.sot, + "eot": self.tokenizer.eot, + "translate": self.tokenizer.translate, + "transcribe": self.tokenizer.transcribe, + } + + # Get language token + self.language_token = self.tokenizer.language_tokens[self.target_language] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + # Load audio + audio_path = item["audio"] + try: + # Process audio to mel spectrogram + mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path) + mel = pad_or_trim(mel, N_SAMPLES) + + # Get text + text = item["text"] + + # Create prompt tokens + prompt_tokens = [ + self.special_tokens["sot"], + self.language_token, + self.special_tokens["transcribe"], + ] + + # Encode the text + target_tokens = ( + prompt_tokens + + self.tokenizer.encode(text) + + [self.special_tokens["eot"]] + ) + + return { + "mel": mel, + "target_tokens": np.array(target_tokens, dtype=np.int64), + "text": text + } + + except Exception as e: + logger.warning(f"Error processing {audio_path}: {e}") + # Return a dummy sample that will be filtered in collate_fn + return None + + @staticmethod + def collate_fn(batch): + """Collate function for DataLoader""" + # Filter None samples + batch = [sample for sample in batch if sample is not None] + if not batch: + return None + + # Get maximum sequence length in this batch + max_target_len = max(len(sample["target_tokens"]) for sample in batch) + + # Initialize tensors + mel_specs = [] + token_ids = [] + labels = [] + + for sample in batch: + target_tokens = sample["target_tokens"] + target_len = len(target_tokens) + + # Prepare inputs and labels for causal LM training + # Input tokens are shifted right + input_tokens = np.zeros(max_target_len, dtype=np.int64) + input_tokens[:target_len-1] = target_tokens[:target_len-1] # Exclude EOS + + # Labels are shifted left and padded with -100 (ignore index) + label = np.full(max_target_len, -100, dtype=np.int64) + label[:target_len-1] = target_tokens[1:target_len] # Start from first token after SOT + + # Add to lists + mel_specs.append(sample["mel"]) + token_ids.append(input_tokens) + labels.append(label) + + # Convert to tensors + mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32) + token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64) + labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64) + + return { + "mel": mel_specs, + "tokens": token_ids, + "labels": labels, + } + + +class WhisperTrainer(Trainer): + """Trainer 
for Whisper fine-tuning""" + + def __init__(self, config, args): + """Initialize the trainer. + + Args: + config: Training configuration + args: Command-line arguments + """ + super().__init__() + self.config = config + self.args = args + self.resource_path = args.resource_path + + # Set random seed + if args.seed is not None: + paddle.seed(args.seed) + np.random.seed(args.seed) + + # Initialize distributed training if needed + if args.distributed: + dist.init_parallel_env() + + # Build model and optimizer + self.model = self._build_model() + self.optimizer = self._build_optimizer() + + # Build tokenizer + self.tokenizer = get_tokenizer( + multilingual=self.model.is_multilingual, + resource_path=self.resource_path, + language=config['data']['target_language'], + task="transcribe" + ) + + # Initialize checkpoint class + self._init_checkpoint() + + def _build_model(self): + """Build Whisper model""" + config = self.config + model_size = config['model']['size'] + + # Load model dimensions + model_dims = MODEL_DIMENSIONS[model_size] + + # Create the model + model = Whisper(model_dims) + + # Load checkpoint if provided + checkpoint_path = config['model']['checkpoint'] + if checkpoint_path: + logger.info(f"Loading checkpoint from {checkpoint_path}") + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + + # Freeze encoder if needed + if config['model'].get('freeze_encoder', False): + logger.info("Freezing encoder parameters") + for param in model.encoder.parameters(): + param.stop_gradient = True + + # Handle distributed training + if self.args.distributed: + model = paddle.DataParallel(model) + + return model + + def _build_optimizer(self): + """Build optimizer and learning rate scheduler""" + config = self.config + model = self.model + + # Set learning rate + learning_rate = config['training']['learning_rate'] + + # Apply weight decay + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + # Build scheduler + scheduler_name = config['training'].get('scheduler', 'linear') + num_training_steps = self.config['training'].get('max_steps', 100000) + + # Calculate warmup steps + warmup_ratio = config['training'].get('warmup_ratio', 0.1) + warmup_steps = int(num_training_steps * warmup_ratio) + + if scheduler_name == 'cosine': + lr_scheduler = LinearWarmup( + CosineAnnealingDecay(learning_rate, num_training_steps - warmup_steps), + warmup_steps, + 0.0, + learning_rate + ) + else: # default to linear + lr_scheduler = paddle.optimizer.lr.LinearWarmup( + paddle.optimizer.lr.PolynomialDecay( + learning_rate=learning_rate, + decay_steps=num_training_steps - warmup_steps, + end_lr=0.0, + power=1.0), + warmup_steps, + 0.0, + learning_rate + ) + + # Create optimizer + optimizer = AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + parameters=model.parameters(), + weight_decay=config['training'].get('weight_decay', 0.01), + grad_clip=nn.ClipGradByNorm(config['training'].get('max_grad_norm', 1.0)), + apply_decay_param_fun=lambda x: x in decay_params + ) + + return optimizer + + def _init_checkpoint(self): + """Initialize checkpoint for saving and loading""" + config = self.config + args = self.args + + checkpoint_dir = Path(config['output']['checkpoint_dir']) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.checkpoint = Checkpoint( + checkpoint_dir=checkpoint_dir, + model=self.model, + optimizer=self.optimizer, + infos=dict(), + visualizer=None, + **{"epoch": 0} + ) + + # Try to 
load from checkpoint if provided + if args.checkpoint_path: + self.checkpoint.load_parameters(args.checkpoint_path) + + def train(self): + """Run training""" + config = self.config + args = self.args + + # Create dataset + train_dataset = WhisperDataset( + manifest_path=config['data']['train_manifest'], + tokenizer=self.tokenizer, + target_language=config['data']['target_language'], + max_duration=config['data']['max_duration'], + min_duration=config['data']['min_duration'], + resource_path=self.resource_path, + ) + + dev_dataset = WhisperDataset( + manifest_path=config['data']['dev_manifest'], + tokenizer=self.tokenizer, + target_language=config['data']['target_language'], + max_duration=config['data']['max_duration'], + min_duration=config['data']['min_duration'], + resource_path=self.resource_path, + ) + + # Create data loaders + train_dataloader = DataLoader( + train_dataset, + batch_size=config['training']['batch_size'], + shuffle=True, + num_workers=config['training'].get('num_workers', 4), + collate_fn=WhisperDataset.collate_fn, + drop_last=True, + ) + + dev_dataloader = DataLoader( + dev_dataset, + batch_size=config['eval']['eval_batch_size'], + shuffle=False, + num_workers=config['training'].get('num_workers', 4), + collate_fn=WhisperDataset.collate_fn, + ) + + # Setup training + max_epoch = config['training']['max_epoch'] + accum_grad = config['training'].get('accum_grad', 1) + log_interval = config['training'].get('log_interval', 100) + save_interval = config['output'].get('save_interval', 1) + + # Initialize reporter for logging + reporter = Reporter() + reporter.register(self.model, "model") + + # Initialize evaluator + evaluator = StandardEvaluator(self.model) + + # Training loop + for epoch in range(max_epoch): + self.model.train() + train_loss = 0 + num_batches = 0 + start_time = time.time() + + for batch_idx, batch in enumerate(train_dataloader): + if batch is None: + continue + + mel = batch["mel"] + tokens = batch["tokens"] + labels = batch["labels"] + + # Forward pass + audio_features = self.model.embed_audio(mel) + logits = self.model.logits(tokens, audio_features) + + # Compute loss + loss = F.cross_entropy( + logits.reshape([-1, logits.shape[-1]]), + labels.reshape([-1]), + ignore_index=-100 + ) + + # Scale loss for gradient accumulation + if accum_grad > 1: + loss = loss / accum_grad + + # Backward pass + loss.backward() + + # Update parameters every accum_grad steps + if (batch_idx + 1) % accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + + # Logging + train_loss += loss.item() * accum_grad + num_batches += 1 + + if batch_idx % log_interval == 0 and get_rank() == 0: + elapsed_time = time.time() - start_time + logger.info(f"Epoch {epoch}/{max_epoch} | Batch {batch_idx}/{len(train_dataloader)} | " + f"Loss: {loss.item()*accum_grad:.4f} | {elapsed_time:.2f}s elapsed") + + # End of epoch + avg_train_loss = train_loss / num_batches if num_batches > 0 else float('inf') + logger.info(f"Epoch {epoch}/{max_epoch} | Average Train Loss: {avg_train_loss:.4f}") + + # Evaluation + self.model.eval() + dev_losses = [] + + with paddle.no_grad(): + for batch in dev_dataloader: + if batch is None: + continue + + mel = batch["mel"] + tokens = batch["tokens"] + labels = batch["labels"] + + # Forward pass + audio_features = self.model.embed_audio(mel) + logits = self.model.logits(tokens, audio_features) + + # Compute loss + loss = F.cross_entropy( + logits.reshape([-1, logits.shape[-1]]), + labels.reshape([-1]), + ignore_index=-100 + ) + + 
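+                    # Positions labelled -100 (padding) are excluded from the loss;
+                    # per-batch values are collected and averaged after the loop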
dev_losses.append(loss.item()) + + avg_dev_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf') + logger.info(f"Epoch {epoch}/{max_epoch} | Validation Loss: {avg_dev_loss:.4f}") + + # Update checkpoint info + self.checkpoint.infos["epoch"] = epoch + self.checkpoint.infos["train_loss"] = avg_train_loss + self.checkpoint.infos["dev_loss"] = avg_dev_loss + + # Save checkpoint + if epoch % save_interval == 0 and get_rank() == 0: + self.checkpoint.save_parameters(tag=f"epoch_{epoch}") + + # Save final model + if get_rank() == 0: + self.checkpoint.save_parameters(tag="final") + logger.info(f"Training completed. Final model saved at {self.checkpoint.checkpoint_dir}") + + +def main(): + parser = argparse.ArgumentParser(description="Train Whisper model") + parser.add_argument("--config", type=str, required=True, help="Path to configuration file") + parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory") + parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from") + parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training") + + args = parser.parse_args() + + # Load configuration + with open(args.config) as f: + config = yaml.safe_load(f) + + # Set device + paddle.set_device(args.device) + + trainer = WhisperTrainer(config, args) + trainer.train() + + +if __name__ == "__main__": + import json + main() diff --git a/examples/commonvoice/whisper/visualize.py b/examples/commonvoice/whisper/visualize.py new file mode 100644 index 000000000..912cd39b7 --- /dev/null +++ b/examples/commonvoice/whisper/visualize.py @@ -0,0 +1,295 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
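+
+"""Visualization utilities for Whisper evaluation results.
+
+Reads the JSON file written by evaluate.py and renders waveforms, reference vs.
+hypothesis comparisons, and aggregate WER/CER plots with matplotlib.
+"""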
+ +import argparse +import json +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import soundfile as sf +from matplotlib.ticker import MaxNLocator + +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +def plot_waveform(audio_path, output_dir=None, show=True): + """Plot waveform of audio file.""" + try: + audio, sr = sf.read(audio_path) + if audio.ndim > 1: + audio = audio[:, 0] # Take first channel if stereo + + duration = len(audio) / sr + time_axis = np.linspace(0, duration, len(audio)) + + plt.figure(figsize=(12, 4)) + plt.plot(time_axis, audio, color='#1f77b4') + plt.title(f"Waveform: {os.path.basename(audio_path)}") + plt.xlabel("Time (seconds)") + plt.ylabel("Amplitude") + plt.grid(alpha=0.3) + + if output_dir: + output_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(audio_path))[0]}_waveform.png") + plt.savefig(output_path, dpi=150, bbox_inches='tight') + logger.info(f"Waveform plot saved to {output_path}") + + if show: + plt.show() + else: + plt.close() + + return True + except Exception as e: + logger.error(f"Error plotting waveform: {e}") + return False + + +def plot_transcription_comparison(reference, hypothesis, audio_path=None, output_dir=None, show=True): + """Plot comparison between reference and hypothesis transcriptions.""" + try: + fig, ax = plt.subplots(figsize=(12, 6)) + + # Title with audio file name if provided + title = "Reference vs. Hypothesis" + if audio_path: + title += f": {os.path.basename(audio_path)}" + plt.title(title) + + # Create comparison table + rows = [] + rows.append(["Reference", reference]) + rows.append(["Hypothesis", hypothesis]) + + # Calculate word-level differences + ref_words = reference.strip().split() + hyp_words = hypothesis.strip().split() + + from difflib import SequenceMatcher + matcher = SequenceMatcher(None, ref_words, hyp_words) + + # Generate color-coded text for differences + ref_colored = [] + hyp_colored = [] + + for op, i1, i2, j1, j2 in matcher.get_opcodes(): + if op == 'equal': + ref_colored.extend(ref_words[i1:i2]) + hyp_colored.extend(hyp_words[j1:j2]) + elif op == 'replace': + ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]]) + hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]]) + elif op == 'delete': + ref_colored.extend([f"**{w}**" for w in ref_words[i1:i2]]) + elif op == 'insert': + hyp_colored.extend([f"**{w}**" for w in hyp_words[j1:j2]]) + + rows.append(["Ref (Diff)", " ".join(ref_colored)]) + rows.append(["Hyp (Diff)", " ".join(hyp_colored)]) + + # Calculate error metrics + word_error = sum(1 for i, j in zip(ref_words, hyp_words) if i != j) + if len(ref_words) > 0: + word_error_rate = word_error / len(ref_words) + else: + word_error_rate = 0 + + rows.append(["Word Error", f"{word_error}/{len(ref_words)} ({word_error_rate:.2%})"]) + + # Create table + ax.axis('tight') + ax.axis('off') + table = ax.table(cellText=rows, colWidths=[0.15, 0.85], loc='center', cellLoc='left') + table.auto_set_font_size(False) + table.set_fontsize(10) + table.scale(1, 1.5) + + # Highlight difference rows with light blue background + for i in [2, 3]: + for j in range(2): + cell = table._cells[(i, j)] + cell.set_facecolor('#E6F3FF') + + if output_dir: + base_name = os.path.splitext(os.path.basename(audio_path))[0] if audio_path else "comparison" + output_path = os.path.join(output_dir, f"{base_name}_comparison.png") + plt.savefig(output_path, dpi=150, bbox_inches='tight') + logger.info(f"Comparison plot saved to 
{output_path}") + + if show: + plt.show() + else: + plt.close() + + return True + except Exception as e: + logger.error(f"Error plotting transcription comparison: {e}") + return False + + +def plot_evaluation_metrics(results_file, output_dir=None, show=True): + """Plot evaluation metrics from results file.""" + try: + with open(results_file, 'r', encoding='utf-8') as f: + results = json.load(f) + + metrics = results['metrics'] + details = results['details'] + + # Extract duration information + durations = [item['duration'] for item in details if 'duration' in item] + if not durations: + durations = [0] * len(details) # Default if no durations available + + # Calculate per-sample WER + from paddlespeech.metrics.wer import word_errors + wers = [] + for item in details: + ref = item['reference'] + hyp = item['hypothesis'] + word_error_count, word_count = word_errors(ref, hyp) + wer = word_error_count / max(1, word_count) + wers.append(wer * 100) # Convert to percentage + + # Create plots + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + + # Overall metrics + ax = axes[0, 0] + metric_names = ['WER', 'CER'] + metric_values = [metrics['wer'] * 100, metrics['cer'] * 100] # Convert to percentage + ax.bar(metric_names, metric_values, color=['#1f77b4', '#ff7f0e']) + ax.set_title('Overall Error Metrics') + ax.set_ylabel('Error Rate (%)') + ax.grid(axis='y', alpha=0.3) + for i, v in enumerate(metric_values): + ax.text(i, v + 0.5, f"{v:.2f}%", ha='center') + + # WER vs Duration scatter plot + ax = axes[0, 1] + ax.scatter(durations, wers, alpha=0.6) + ax.set_title('WER vs Audio Duration') + ax.set_xlabel('Duration (seconds)') + ax.set_ylabel('WER (%)') + ax.grid(alpha=0.3) + + # WER distribution histogram + ax = axes[1, 0] + ax.hist(wers, bins=20, color='#2ca02c', alpha=0.7) + ax.set_title('WER Distribution') + ax.set_xlabel('WER (%)') + ax.set_ylabel('Number of Samples') + ax.grid(alpha=0.3) + + # Sample count by WER range + ax = axes[1, 1] + wer_ranges = [0, 5, 10, 20, 30, 50, 100] + wer_bins = pd.cut(wers, wer_ranges, right=False) + wer_counts = pd.value_counts(wer_bins).sort_index() + + wer_labels = [f"{wer_ranges[i]}-{wer_ranges[i+1]}%" for i in range(len(wer_ranges)-1)] + ax.bar(wer_labels, wer_counts, color='#d62728') + ax.set_title('Sample Count by WER Range') + ax.set_xlabel('WER Range') + ax.set_ylabel('Number of Samples') + ax.tick_params(axis='x', rotation=45) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + ax.grid(axis='y', alpha=0.3) + + plt.tight_layout() + + if output_dir: + output_path = os.path.join(output_dir, "evaluation_metrics.png") + plt.savefig(output_path, dpi=150, bbox_inches='tight') + logger.info(f"Metrics plots saved to {output_path}") + + if show: + plt.show() + else: + plt.close() + + return True + except Exception as e: + logger.error(f"Error plotting evaluation metrics: {e}") + return False + + +def visualize_results(results_file, output_dir=None, audio_dir=None, num_samples=5, show=True): + """Visualize evaluation results.""" + try: + # Create output directory if needed + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + # Load results + with open(results_file, 'r', encoding='utf-8') as f: + results = json.load(f) + + # Plot overall metrics + plot_evaluation_metrics(results_file, output_dir, show) + + # Plot individual examples + if num_samples > 0: + details = results['details'] + samples = details[:num_samples] + + for i, sample in enumerate(samples): + audio_path = sample['audio'] + reference = sample['reference'] + hypothesis = 
sample['hypothesis'] + + # If audio directory is provided, use audio files from there + if audio_dir: + base_name = os.path.basename(audio_path) + audio_path = os.path.join(audio_dir, base_name) + + # Plot waveform if audio file exists + if os.path.exists(audio_path): + plot_waveform(audio_path, output_dir, show) + + # Plot transcription comparison + plot_transcription_comparison(reference, hypothesis, audio_path, output_dir, show) + + return True + except Exception as e: + logger.error(f"Error visualizing results: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Visualize Whisper evaluation results") + parser.add_argument("--results_file", type=str, required=True, help="Path to evaluation results JSON file") + parser.add_argument("--output_dir", type=str, default="./visualizations", help="Directory to save visualization outputs") + parser.add_argument("--audio_dir", type=str, default=None, help="Directory containing audio files (optional)") + parser.add_argument("--num_samples", type=int, default=5, help="Number of individual samples to visualize") + parser.add_argument("--show", action="store_true", help="Show plots interactively") + + args = parser.parse_args() + + visualize_results( + results_file=args.results_file, + output_dir=args.output_dir, + audio_dir=args.audio_dir, + num_samples=args.num_samples, + show=args.show + ) + + +if __name__ == "__main__": + main() diff --git a/examples/commonvoice/whisper/whisper_cli.py b/examples/commonvoice/whisper/whisper_cli.py new file mode 100644 index 000000000..f8f82c23a --- /dev/null +++ b/examples/commonvoice/whisper/whisper_cli.py @@ -0,0 +1,240 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
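+
+"""Unified command-line entry point for the Whisper fine-tuning example.
+
+Subcommands: prepare (build manifests), train (fine-tune a model), infer
+(transcribe audio files) and evaluate (score a test manifest). Run
+"python whisper_cli.py COMMAND --help" for the options of each subcommand.
+"""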
+ +import argparse +import os +import sys +from pathlib import Path + +import paddle +import yaml + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import str2bool + +logger = Log(__name__).getlog() + + +def prepare_parser(): + """Prepare argument parser.""" + parser = argparse.ArgumentParser( + description="Whisper CLI for fine-tuning, inference, and evaluation", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # Create subparsers for different commands + subparsers = parser.add_subparsers(dest="command", help="Command to run") + + # Data preparation command + data_parser = subparsers.add_parser("prepare", help="Prepare data for fine-tuning") + data_parser.add_argument("--language", type=str, default="en", help="Language code (e.g., en, fr, es)") + data_parser.add_argument("--output_dir", type=str, default="./data", help="Directory to save preprocessed data") + data_parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets") + data_parser.add_argument("--val_size", type=float, default=0.03, help="Validation set size as fraction") + data_parser.add_argument("--test_size", type=float, default=0.03, help="Test set size as fraction") + data_parser.add_argument("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds") + data_parser.add_argument("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds") + data_parser.add_argument("--seed", type=int, default=42, help="Random seed") + + # Fine-tuning command + train_parser = subparsers.add_parser("train", help="Fine-tune Whisper model") + train_parser.add_argument("--config", type=str, required=True, help="Path to configuration file") + train_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory") + train_parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use") + train_parser.add_argument("--seed", type=int, default=42, help="Random seed") + train_parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from") + train_parser.add_argument("--distributed", type=str2bool, default=False, help="Enable distributed training") + + # Inference command + infer_parser = subparsers.add_parser("infer", help="Inference with Whisper model") + infer_parser.add_argument("--audio_file", type=str, help="Path to audio file for transcription") + infer_parser.add_argument("--audio_dir", type=str, help="Path to directory containing audio files") + infer_parser.add_argument("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions") + infer_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning") + infer_parser.add_argument("--resource_path", type=str, default="./resources", + help="Path to resources directory containing original models and assets") + + # Model options for inference + infer_parser.add_argument("--use_original", type=str2bool, default=False, + help="Use original Whisper model instead of fine-tuned") + infer_parser.add_argument("--model_size", type=str, default="base", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], + help="Model size for original Whisper") + + # Decoding options for inference + infer_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)") + infer_parser.add_argument("--task", type=str, 
default="transcribe", choices=["transcribe", "translate"], + help="Task: transcribe or translate to English") + infer_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling") + infer_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search") + infer_parser.add_argument("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps") + + # Evaluation command + eval_parser = subparsers.add_parser("evaluate", help="Evaluate Whisper model") + eval_parser.add_argument("--manifest", type=str, required=True, help="Path to manifest file") + eval_parser.add_argument("--output_dir", type=str, default="./results", help="Directory to save results") + eval_parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning") + eval_parser.add_argument("--resource_path", type=str, default="./resources", help="Path to resources directory") + + # Model options for evaluation + eval_parser.add_argument("--model_size", type=str, default="base", + choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], + help="Model size for original Whisper") + + # Decoding options for evaluation + eval_parser.add_argument("--language", type=str, default=None, help="Language code (e.g., en, fr)") + eval_parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], + help="Task: transcribe or translate to English") + eval_parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling") + eval_parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search") + eval_parser.add_argument("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps") + eval_parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate") + + return parser + + +def main(): + parser = prepare_parser() + args = parser.parse_args() + + if not args.command: + parser.print_help() + sys.exit(1) + + # Execute the appropriate command + if args.command == "prepare": + from prepare_data import prepare_common_voice + prepare_common_voice( + language=args.language, + output_dir=args.output_dir, + cache_dir=args.cache_dir, + val_size=args.val_size, + test_size=args.test_size, + min_duration=args.min_duration, + max_duration=args.max_duration, + seed=args.seed + ) + + elif args.command == "train": + import train + if args.distributed: + import paddle.distributed as dist + dist.init_parallel_env() + + # Load configuration + with open(args.config) as f: + config = yaml.safe_load(f) + + # Set device + paddle.set_device(args.device) + + trainer = train.WhisperTrainer(config, args) + trainer.train() + + elif args.command == "infer": + from infer import load_whisper_model, transcribe_file, batch_transcribe + + # Validate arguments + if not args.audio_file and not args.audio_dir: + logger.error("Either --audio_file or --audio_dir must be specified") + sys.exit(1) + + if args.use_original: + # Use original Whisper model + if not args.resource_path: + logger.error("--resource_path must be specified when using original model") + sys.exit(1) + + model = load_whisper_model( + model_size=args.model_size, + resource_path=args.resource_path + ) + else: + # Use fine-tuned model + if not args.checkpoint: + logger.error("--checkpoint must be specified when not using original model") + sys.exit(1) + + model = load_whisper_model( + model_size=args.model_size, + 
checkpoint_path=args.checkpoint, + resource_path=args.resource_path + ) + + # Prepare transcription keyword arguments + transcribe_kwargs = { + "language": args.language, + "task": args.task, + "temperature": args.temperature, + "beam_size": args.beam_size, + "without_timestamps": args.without_timestamps, + } + + # Run transcription + if args.audio_file: + result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs) + if result: + print("-" * 40) + print("Transcription:") + print(result["text"]) + print("-" * 40) + + # Save to output directory if specified + if args.output_dir: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + file_name = os.path.splitext(os.path.basename(args.audio_file))[0] + with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f: + f.write(result["text"].strip()) + + import json + with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + print(f"Results saved to {output_dir}") + + elif args.audio_dir: + batch_transcribe( + model, + args.audio_dir, + args.output_dir, + args.resource_path, + **transcribe_kwargs + ) + + elif args.command == "evaluate": + from evaluate import load_model, evaluate_manifest + + model = load_model( + model_size=args.model_size, + checkpoint_path=args.checkpoint, + resource_path=args.resource_path + ) + + evaluate_manifest( + model=model, + manifest_path=args.manifest, + resource_path=args.resource_path, + language=args.language, + task=args.task, + temperature=args.temperature, + beam_size=args.beam_size, + without_timestamps=args.without_timestamps, + output_dir=args.output_dir, + max_samples=args.max_samples + ) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/s2t/models/whisper/fine_tune/__init__.py b/paddlespeech/s2t/models/whisper/fine_tune/__init__.py new file mode 100644 index 000000000..322f20e59 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/fine_tune/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Whisper fine-tuning module. + +This module provides utilities and classes for fine-tuning Whisper models +on custom datasets, following the paper "Robust Speech Recognition via +Large-Scale Weak Supervision" and Hugging Face's fine-tuning approach. + +References: + - Radford, A. et al. (2022). Robust Speech Recognition via Large-Scale Weak Supervision. 
+ - https://github.com/openai/whisper + - https://huggingface.co/blog/fine-tune-whisper + - Whisper Original Paper: https://arxiv.org/abs/2212.04356 +""" + +from paddlespeech.s2t.models.whisper.fine_tune.dataset import ( + WhisperDataset, + WhisperInferenceDataset +) + +from paddlespeech.s2t.models.whisper.fine_tune.trainer import ( + WhisperTrainer +) + +__all__ = [ + 'WhisperDataset', + 'WhisperInferenceDataset', + 'WhisperTrainer', +] diff --git a/paddlespeech/s2t/models/whisper/fine_tune/dataset.py b/paddlespeech/s2t/models/whisper/fine_tune/dataset.py new file mode 100644 index 000000000..fdfb2fca6 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/fine_tune/dataset.py @@ -0,0 +1,360 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import paddle +from paddle.io import Dataset, DataLoader + +from paddlespeech.s2t.models.whisper.whisper import (N_MELS, N_SAMPLES, log_mel_spectrogram, pad_or_trim) +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +class WhisperDataset(Dataset): + """Dataset for Whisper fine-tuning""" + + def __init__( + self, + manifest_path: str, + tokenizer, + target_language: str = "en", + task: str = "transcribe", + max_duration: float = 30.0, + min_duration: float = 0.5, + sample_rate: int = 16000, + resource_path: str = '', + pad_to_max_length: bool = False, + ): + """Initialize the dataset. 
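+
+        Each manifest line is a JSON object with "audio", "text" and
+        "duration" fields; items whose duration falls outside
+        [min_duration, max_duration] are skipped when the manifest is loaded.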
+ + Args: + manifest_path: Path to manifest file with audio paths and transcripts + tokenizer: Whisper tokenizer + target_language: Target language code + task: Task type, either 'transcribe' or 'translate' + max_duration: Maximum audio duration + min_duration: Minimum audio duration + sample_rate: Audio sample rate + resource_path: Path to resources directory + pad_to_max_length: Whether to pad all sequences to maximum length in batch + """ + super().__init__() + + self.tokenizer = tokenizer + self.target_language = target_language + self.task = task + self.sample_rate = sample_rate + self.resource_path = resource_path + self.pad_to_max_length = pad_to_max_length + + # Load manifest + with open(manifest_path, 'r', encoding='utf8') as f: + manifest_lines = f.readlines() + + self.data = [] + for line in manifest_lines: + try: + item = json.loads(line.strip()) + duration = item.get('duration', 0) + if min_duration <= duration <= max_duration: + self.data.append(item) + except Exception as e: + logger.warning(f"Error parsing line in manifest: {e}") + + logger.info(f"Loaded {len(self.data)} examples from {manifest_path}") + + # Generate special tokens and language tokens + self.special_tokens = { + "sot": self.tokenizer.sot, + "eot": self.tokenizer.eot, + "translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None, + "transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None, + "no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None, + } + + # Get language token + self.language_token = self.tokenizer.language_tokens.get(self.target_language) if hasattr(self.tokenizer, "language_tokens") else None + if not self.language_token and hasattr(self.tokenizer, "language_token"): + self.language_token = self.tokenizer.language_token(self.target_language) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + # Load audio + audio_path = item["audio"] + try: + # Process audio to mel spectrogram + mel = log_mel_spectrogram(audio_path, n_mels=N_MELS, resource_path=self.resource_path) + mel = pad_or_trim(mel, N_SAMPLES) + + # Get text + text = item["text"] + + # Create prompt tokens + prompt_tokens = [self.special_tokens["sot"]] + + # Add language token if available + if self.language_token is not None: + prompt_tokens.append(self.language_token) + + # Add task token if available + task_token = self.special_tokens.get(self.task) + if task_token is not None: + prompt_tokens.append(task_token) + + # Add no_timestamps token if available + if self.special_tokens["no_timestamps"] is not None: + prompt_tokens.append(self.special_tokens["no_timestamps"]) + + # Encode the text + target_tokens = ( + prompt_tokens + + self.tokenizer.encode(text) + + [self.special_tokens["eot"]] + ) + + return { + "mel": mel, + "target_tokens": np.array(target_tokens, dtype=np.int64), + "text": text, + "audio_path": audio_path + } + + except Exception as e: + logger.warning(f"Error processing {audio_path}: {e}") + # Return a dummy sample that will be filtered in collate_fn + return None + + @staticmethod + def collate_fn(batch, pad_idx=-100): + """Collate function for DataLoader""" + # Filter None samples + batch = [sample for sample in batch if sample is not None] + if not batch: + return None + + # Get maximum sequence length in this batch + max_target_len = max(len(sample["target_tokens"]) for sample in batch) + + # Initialize tensors + mel_specs = [] + token_ids = [] + labels = 
[] + + # Also collect metadata for debugging/logging + texts = [] + audio_paths = [] + + for sample in batch: + target_tokens = sample["target_tokens"] + target_len = len(target_tokens) + + # Prepare inputs and labels for causal LM training + # Input tokens are shifted right + input_tokens = np.zeros(max_target_len, dtype=np.int64) + input_tokens[:target_len-1] = target_tokens[:target_len-1] # Exclude EOS + + # Labels are shifted left and padded with pad_idx (ignore index) + label = np.full(max_target_len, pad_idx, dtype=np.int64) + label[:target_len-1] = target_tokens[1:target_len] # Start from first token after SOT + + # Add to lists + mel_specs.append(sample["mel"]) + token_ids.append(input_tokens) + labels.append(label) + + # Collect metadata + texts.append(sample["text"]) + audio_paths.append(sample["audio_path"]) + + # Convert to tensors + mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32) + token_ids = paddle.to_tensor(np.array(token_ids), dtype=paddle.int64) + labels = paddle.to_tensor(np.array(labels), dtype=paddle.int64) + + return { + "mel": mel_specs, + "tokens": token_ids, + "labels": labels, + "texts": texts, + "audio_paths": audio_paths, + } + + def create_dataloader(self, + batch_size=16, + num_workers=4, + shuffle=True, + drop_last=False): + """Create a dataloader from this dataset""" + return DataLoader( + self, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=self.collate_fn, + drop_last=drop_last + ) + + +class WhisperInferenceDataset(Dataset): + """Dataset for Whisper inference with batching support""" + + def __init__( + self, + audio_paths: Union[str, List[str]], + tokenizer=None, + language: Optional[str] = None, + task: str = "transcribe", + sample_rate: int = 16000, + resource_path: str = '', + ): + """Initialize the inference dataset. 
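+
+        When audio_paths is a directory, every .wav/.mp3/.flac/.ogg file
+        directly inside it is used; subdirectories are not searched.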
+ + Args: + audio_paths: Path to audio file or directory, or list of audio file paths + tokenizer: Whisper tokenizer (optional, only needed for preparing decoder input) + language: Language code (e.g., 'en', 'fr') + task: Task type, either 'transcribe' or 'translate' + sample_rate: Audio sample rate + resource_path: Path to resources directory + """ + super().__init__() + + self.tokenizer = tokenizer + self.language = language + self.task = task + self.sample_rate = sample_rate + self.resource_path = resource_path + + # Process audio paths + if isinstance(audio_paths, str): + # Single file + if os.path.isfile(audio_paths): + self.audio_paths = [audio_paths] + # Directory + elif os.path.isdir(audio_paths): + self.audio_paths = [ + os.path.join(audio_paths, f) for f in os.listdir(audio_paths) + if f.endswith(('.wav', '.mp3', '.flac', '.ogg')) + ] + else: + raise ValueError(f"Path not found: {audio_paths}") + else: + # List of files + self.audio_paths = audio_paths + + logger.info(f"Loaded {len(self.audio_paths)} audio files for inference") + + # Generate special tokens and language tokens if tokenizer provided + self.special_tokens = None + self.language_token = None + + if self.tokenizer: + self.special_tokens = { + "sot": self.tokenizer.sot, + "eot": self.tokenizer.eot, + "translate": self.tokenizer.translate if hasattr(self.tokenizer, "translate") else None, + "transcribe": self.tokenizer.transcribe if hasattr(self.tokenizer, "transcribe") else None, + "no_timestamps": self.tokenizer.no_timestamps if hasattr(self.tokenizer, "no_timestamps") else None, + } + + # Get language token + if self.language and hasattr(self.tokenizer, "language_tokens"): + self.language_token = self.tokenizer.language_tokens.get(self.language) + elif self.language and hasattr(self.tokenizer, "language_token"): + self.language_token = self.tokenizer.language_token(self.language) + + def __len__(self): + return len(self.audio_paths) + + def __getitem__(self, idx): + audio_path = self.audio_paths[idx] + + # Compute mel spectrogram + try: + mel = log_mel_spectrogram(audio_path, self.resource_path, self.sample_rate) + except Exception as e: + logger.error(f"Error processing audio file {audio_path}: {e}") + # Return zero tensor with correct shape in case of error + mel = paddle.zeros((N_MELS, N_SAMPLES // 160)) + + return { + "mel": mel.numpy(), + "audio_path": audio_path, + } + + @staticmethod + def collate_fn(batch): + """Collate function for inference DataLoader""" + # Extract mel spectrograms and audio paths + mel_specs = [] + audio_paths = [] + + for sample in batch: + mel_specs.append(sample["mel"]) + audio_paths.append(sample["audio_path"]) + + # Convert to tensors + mel_specs = paddle.to_tensor(np.array(mel_specs), dtype=paddle.float32) + + return { + "mel": mel_specs, + "audio_paths": audio_paths, + } + + def create_dataloader(self, + batch_size=1, + num_workers=4): + """Create a dataloader from this inference dataset""" + return DataLoader( + self, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, + drop_last=False + ) + + def prepare_decoder_input(self): + """Prepare decoder input tokens for initial prompt + + Only used when tokenizer is provided. 
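+
+        Returns:
+            An int64 tensor of shape [1, num_prompt_tokens] holding the SOT,
+            language, task and (optionally) no-timestamps tokens, or None
+            when no tokenizer was supplied.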
+ """ + if not self.tokenizer: + logger.warning("Cannot prepare decoder input without tokenizer") + return None + + # Create initial tokens - similar to training but without labels + tokens = [self.special_tokens["sot"]] + + if self.language_token: + tokens.append(self.language_token) + + if self.task == "translate": + tokens.append(self.special_tokens["translate"]) + elif self.task == "transcribe": + tokens.append(self.special_tokens["transcribe"]) + + if self.special_tokens["no_timestamps"]: + tokens.append(self.special_tokens["no_timestamps"]) + + return paddle.to_tensor([tokens], dtype=paddle.int64) diff --git a/paddlespeech/s2t/models/whisper/fine_tune/trainer.py b/paddlespeech/s2t/models/whisper/fine_tune/trainer.py new file mode 100644 index 000000000..4c613c932 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/fine_tune/trainer.py @@ -0,0 +1,380 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.distributed as dist +from paddle.io import DataLoader +from paddle.optimizer import AdamW +from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup + +from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer +from paddlespeech.s2t.models.whisper.whisper import MODEL_DIMENSIONS, Whisper +from paddlespeech.s2t.training.reporter import ObsScope, Reporter +from paddlespeech.s2t.training.timer import Timer +from paddlespeech.s2t.utils.checkpoint import Checkpoint +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import get_rank + +logger = Log(__name__).getlog() + + +class WhisperTrainer: + """Trainer for fine-tuning Whisper models.""" + + def __init__( + self, + config: Dict, + model: Optional[Whisper] = None, + optimizer: Optional[paddle.optimizer.Optimizer] = None, + checkpoint_dir: Optional[Union[str, Path]] = None, + resource_path: Optional[str] = None, + log_interval: int = 10, + save_interval: int = 1, + rank: int = 0, + world_size: int = 1, + ): + """Initialize the trainer. 
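+
+        The model and optimizer are built from the config when they are not
+        passed in explicitly; the tokenizer is always created from the
+        language/task settings in the config and the given resource_path.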
+ + Args: + config: Training configuration dictionary + model: Whisper model instance, or None to build one from config + optimizer: Optimizer instance, or None to build one from config + checkpoint_dir: Directory to save checkpoints + resource_path: Path to resources directory containing whisper assets + log_interval: Steps between logging + save_interval: Epochs between checkpoint saving + rank: Local rank for distributed training + world_size: Total number of processes for distributed training + """ + self.config = config + self.resource_path = resource_path + self.log_interval = log_interval + self.save_interval = save_interval + self.rank = rank + self.world_size = world_size + + # Initialize model if not provided + self.model = model if model is not None else self._build_model() + + # Initialize tokenizer + self.tokenizer = self._build_tokenizer() + + # Initialize optimizer if not provided + self.optimizer = optimizer if optimizer is not None else self._build_optimizer() + + # Initialize checkpoint + self.checkpoint = self._init_checkpoint(checkpoint_dir) + + # Initialize reporter and timer + self.reporter = Reporter() + self.timer = Timer() + + def _build_model(self) -> Whisper: + """Build Whisper model from config.""" + model_config = self.config['model'] + model_size = model_config['size'] + + # Load model dimensions + model_dims = MODEL_DIMENSIONS[model_size] + + # Create the model + model = Whisper(model_dims) + + # Load checkpoint if provided + checkpoint_path = model_config.get('checkpoint') + if checkpoint_path: + logger.info(f"Loading model weights from {checkpoint_path}") + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + + # Freeze encoder if needed + if model_config.get('freeze_encoder', False): + logger.info("Freezing encoder parameters") + for param in model.encoder.parameters(): + param.stop_gradient = True + + # Handle distributed training + if self.world_size > 1: + logger.info(f"Initializing distributed model with {self.world_size} processes") + model = paddle.DataParallel(model) + + return model + + def _build_tokenizer(self): + """Build Whisper tokenizer.""" + target_language = self.config['data'].get('target_language', 'en') + task = self.config['data'].get('task', 'transcribe') + + return get_tokenizer( + multilingual=self.model.is_multilingual, + resource_path=self.resource_path, + language=target_language, + task=task + ) + + def _build_optimizer(self) -> paddle.optimizer.Optimizer: + """Build optimizer and learning rate scheduler.""" + training_config = self.config['training'] + + # Set learning rate + learning_rate = training_config['learning_rate'] + + # Apply weight decay + decay_params = [ + p.name for n, p in self.model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + # Build scheduler + scheduler_name = training_config.get('scheduler', 'linear') + max_steps = training_config.get('max_steps', 100000) + + # Calculate warmup steps + warmup_ratio = training_config.get('warmup_ratio', 0.1) + warmup_steps = int(max_steps * warmup_ratio) + + if scheduler_name == 'cosine': + lr_scheduler = LinearWarmup( + CosineAnnealingDecay(learning_rate, max_steps - warmup_steps), + warmup_steps, + 0.0, + learning_rate + ) + else: # default to linear + lr_scheduler = paddle.optimizer.lr.LinearWarmup( + paddle.optimizer.lr.PolynomialDecay( + learning_rate=learning_rate, + decay_steps=max_steps - warmup_steps, + end_lr=0.0, + power=1.0), + warmup_steps, + 0.0, + learning_rate + ) + + # Create optimizer + weight_decay = 
training_config.get('weight_decay', 0.01) + max_grad_norm = training_config.get('max_grad_norm', 1.0) + + optimizer = AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + parameters=self.model.parameters(), + weight_decay=weight_decay, + grad_clip=nn.ClipGradByNorm(max_grad_norm), + apply_decay_param_fun=lambda x: x in decay_params + ) + + return optimizer + + def _init_checkpoint(self, checkpoint_dir=None) -> Checkpoint: + """Initialize checkpoint for saving and loading.""" + if checkpoint_dir is None: + checkpoint_dir = self.config['output'].get('checkpoint_dir', './exp/whisper_fine_tune') + + checkpoint_dir = Path(checkpoint_dir) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + return Checkpoint( + checkpoint_dir=checkpoint_dir, + model=self.model, + optimizer=self.optimizer, + infos=dict(), + visualizer=None, + **{"epoch": 0} + ) + + def train(self, train_loader: DataLoader, dev_loader: Optional[DataLoader] = None, num_epochs: int = 10): + """Run training loop. + + Args: + train_loader: DataLoader for training data + dev_loader: DataLoader for validation data + num_epochs: Number of epochs to train + """ + self.reporter.register(self.model, "model") + + # Get training configuration + training_config = self.config['training'] + max_epoch = num_epochs or training_config.get('max_epoch', 10) + accum_grad = training_config.get('accum_grad', 1) + + # Start training + logger.info(f"Starting training for {max_epoch} epochs") + self.timer.start() + + # Resume from checkpoint if epoch > 0 + start_epoch = self.checkpoint.infos.get("epoch", 0) + + for epoch in range(start_epoch, max_epoch): + self._train_epoch(train_loader, epoch, accum_grad) + + # Validation + if dev_loader is not None: + dev_loss = self._eval_epoch(dev_loader, epoch) + self.checkpoint.infos["dev_loss"] = dev_loss + + # Update epoch in checkpoint + self.checkpoint.infos["epoch"] = epoch + 1 + + # Save checkpoint + if (epoch + 1) % self.save_interval == 0 and self.rank == 0: + logger.info(f"Saving checkpoint at epoch {epoch + 1}") + self.checkpoint.save_parameters(tag=f"epoch_{epoch + 1}") + + # Save final model + if self.rank == 0: + logger.info("Saving final model") + self.checkpoint.save_parameters(tag="final") + + def _train_epoch(self, train_loader: DataLoader, epoch: int, accum_grad: int = 1): + """Train for one epoch.""" + self.model.train() + + train_loss = 0.0 + num_batches = 0 + steps_per_epoch = len(train_loader) + + for batch_idx, batch in enumerate(train_loader): + if batch is None: + continue + + # Get batch data + mel = batch["mel"] + tokens = batch["tokens"] + labels = batch["labels"] + + # Forward pass + audio_features = self.model.embed_audio(mel) + logits = self.model.logits(tokens, audio_features) + + # Compute loss + loss = F.cross_entropy( + logits.reshape([-1, logits.shape[-1]]), + labels.reshape([-1]), + ignore_index=-100 + ) + + # Scale loss for gradient accumulation + if accum_grad > 1: + loss = loss / accum_grad + + # Backward pass + loss.backward() + + # Update parameters every accum_grad steps + if (batch_idx + 1) % accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + + # Logging + train_loss += loss.item() * (accum_grad if accum_grad > 1 else 1) + num_batches += 1 + + # Log training progress + if batch_idx % self.log_interval == 0 and self.rank == 0: + elapsed_time = self.timer.elapsed_interval() + step = epoch * steps_per_epoch + batch_idx + logger.info(f"Epoch {epoch+1} | Batch {batch_idx}/{steps_per_epoch} | " + f"Loss: 
{loss.item()*(accum_grad if accum_grad > 1 else 1):.4f} | " + f"Step {step} | {elapsed_time:.2f}s elapsed") + + # Compute average loss + avg_loss = train_loss / num_batches if num_batches > 0 else float('inf') + + # Log epoch stats + if self.rank == 0: + logger.info(f"Epoch {epoch+1} | Average Training Loss: {avg_loss:.4f}") + self.checkpoint.infos["train_loss"] = avg_loss + + def _eval_epoch(self, dev_loader: DataLoader, epoch: int): + """Evaluate for one epoch.""" + self.model.eval() + + dev_losses = [] + + with paddle.no_grad(): + for batch in dev_loader: + if batch is None: + continue + + # Get batch data + mel = batch["mel"] + tokens = batch["tokens"] + labels = batch["labels"] + + # Forward pass + audio_features = self.model.embed_audio(mel) + logits = self.model.logits(tokens, audio_features) + + # Compute loss + loss = F.cross_entropy( + logits.reshape([-1, logits.shape[-1]]), + labels.reshape([-1]), + ignore_index=-100 + ) + + dev_losses.append(loss.item()) + + avg_loss = sum(dev_losses) / len(dev_losses) if dev_losses else float('inf') + + if self.rank == 0: + logger.info(f"Epoch {epoch+1} | Validation Loss: {avg_loss:.4f}") + + return avg_loss + + def save(self, tag: str = "final"): + """Save model checkpoint.""" + if self.rank == 0: + logger.info(f"Saving model checkpoint with tag '{tag}'") + self.checkpoint.save_parameters(tag=tag) + + def load(self, checkpoint_path: Union[str, Path]): + """Load model from checkpoint.""" + logger.info(f"Loading checkpoint from {checkpoint_path}") + self.checkpoint.load_parameters(checkpoint_path) + + @classmethod + def from_pretrained(cls, + config: Dict, + model_size: str = "base", + checkpoint_path: Optional[str] = None, + resource_path: Optional[str] = None): + """Create a trainer from pretrained model.""" + # Create model + model_dims = MODEL_DIMENSIONS[model_size] + model = Whisper(model_dims) + + # Load checkpoint if provided + if checkpoint_path: + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + + # Create trainer + trainer = cls( + config=config, + model=model, + resource_path=resource_path + ) + + return trainer diff --git a/paddlespeech/s2t/models/whisper/fine_tune/utils.py b/paddlespeech/s2t/models/whisper/fine_tune/utils.py new file mode 100644 index 000000000..ce6daed04 --- /dev/null +++ b/paddlespeech/s2t/models/whisper/fine_tune/utils.py @@ -0,0 +1,271 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
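+
+"""Helper utilities for Whisper fine-tuning: model loading, language
+detection, timestamp and SRT formatting, and batched transcription."""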
+ +import os +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import paddle +import yaml + +from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer +from paddlespeech.s2t.models.whisper.whisper import ( + MODEL_DIMENSIONS, + LANGUAGES, + Whisper, + DecodingOptions, + load_model, + transcribe +) +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + + +def load_whisper_model( + model_size: str = "base", + checkpoint_path: Optional[str] = None, + resource_path: Optional[str] = None, +) -> Whisper: + """Load Whisper model from checkpoint or pretrained weights. + + Args: + model_size: Model size for Whisper + checkpoint_path: Path to model checkpoint from fine-tuning + resource_path: Path to resources directory containing original models + + Returns: + Whisper model + """ + model_dims = MODEL_DIMENSIONS[model_size] + model = Whisper(model_dims) + + if checkpoint_path: + logger.info(f"Loading model from checkpoint: {checkpoint_path}") + state_dict = paddle.load(checkpoint_path) + model.set_state_dict(state_dict) + elif resource_path: + model_path = os.path.join(resource_path, "whisper", f"whisper-{model_size}.pdparams") + if os.path.exists(model_path): + logger.info(f"Loading pretrained model from: {model_path}") + state_dict = paddle.load(model_path) + model.set_state_dict(state_dict) + else: + logger.error(f"Pretrained model not found at {model_path}") + raise FileNotFoundError(f"Model file not found: {model_path}") + else: + logger.error("Either checkpoint_path or resource_path must be provided") + raise ValueError("Either checkpoint_path or resource_path must be provided") + + return model + + +def detect_language( + model: Whisper, + mel: paddle.Tensor, + tokenizer=None, + resource_path: Optional[str] = None +) -> str: + """Detect language from audio. + + Args: + model: Whisper model + mel: Mel spectrogram + tokenizer: Optional tokenizer + resource_path: Path to resources directory (required if tokenizer not provided) + + Returns: + Detected language code + """ + if not tokenizer and not resource_path: + raise ValueError("Either tokenizer or resource_path must be provided") + + if not tokenizer: + tokenizer = get_tokenizer( + multilingual=model.is_multilingual, + resource_path=resource_path + ) + + # Get audio features from encoder + audio_features = model.embed_audio(mel) + + # Get initial tokens + initial_tokens = paddle.to_tensor([[tokenizer.sot]], dtype=paddle.int64) + + # Extract language token logits + token_logits = model.logits(initial_tokens, audio_features) + language_token_logits = token_logits[:, 0, tokenizer.language_token_ranges] + + # Get language token with highest probability + language_token_probs = paddle.softmax(language_token_logits, axis=-1) + language_token_id = paddle.argmax(language_token_probs, axis=-1) + detected_language_token_id = language_token_id.item() + tokenizer.language_token_ranges[0] + + # Map token to language code + language_token = tokenizer.all_tokens[detected_language_token_id] + language_code = language_token.strip("<>") + + return language_code + + +def get_available_languages() -> List[str]: + """Get list of available languages in Whisper. + + Returns: + List of language codes + """ + return sorted(LANGUAGES.keys()) + + +def format_timestamp( + seconds: float, + always_include_hours: bool = False, + decimal_marker: str = "." +) -> str: + """Format a timestamp into a string. 
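+
+    For example, 75.5 seconds is rendered as "01:15.50", or "00:01:15.50"
+    when always_include_hours is True.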
+ + Args: + seconds: Number of seconds + always_include_hours: Whether to always include hours + decimal_marker: Decimal marker character + + Returns: + Formatted timestamp string + """ + seconds = max(0, seconds) + hours = seconds // 3600 + seconds = seconds - hours * 3600 + minutes = seconds // 60 + seconds = seconds - minutes * 60 + + hours_marker = f"{int(hours):02d}:" if always_include_hours or hours > 0 else "" + return f"{hours_marker}{int(minutes):02d}:{int(seconds):02d}{decimal_marker}{int(seconds * 100 % 100):02d}" + + +def save_srt( + transcript: Dict, + file_path: str +): + """Save transcript to SRT file. + + Args: + transcript: Transcript dictionary from Whisper + file_path: Path to output SRT file + """ + if not transcript.get("segments"): + return + + with open(file_path, "w", encoding="utf-8") as f: + for i, segment in enumerate(transcript["segments"], start=1): + start = format_timestamp(segment["start"], always_include_hours=True, decimal_marker=",") + end = format_timestamp(segment["end"], always_include_hours=True, decimal_marker=",") + text = segment["text"].strip().replace("-->", "->") + + f.write(f"{i}\n") + f.write(f"{start} --> {end}\n") + f.write(f"{text}\n\n") + + +def batch_transcribe_with_progress( + model: Whisper, + dataset, + dataloader, + resource_path: str, + language: Optional[str] = None, + task: str = "transcribe", + beam_size: int = 5, + temperature: float = 0.0, + without_timestamps: bool = True, + verbose: bool = False, + decoder_options: Optional[dict] = None, +): + """Transcribe audio files in batches with progress reporting. + + Args: + model: Whisper model + dataset: WhisperInferenceDataset + dataloader: DataLoader from dataset + resource_path: Path to resources directory + language: Language code or None for auto-detection + task: Task (transcribe or translate) + beam_size: Beam size for beam search + temperature: Temperature for sampling + without_timestamps: Whether to include timestamps + verbose: Whether to show verbose output + decoder_options: Additional decoder options + + Returns: + List of transcription results + """ + model.eval() + tokenizer = get_tokenizer( + multilingual=model.is_multilingual, + resource_path=resource_path, + language=language, + task=task + ) + + results = [] + total_batches = len(dataloader) + + with paddle.no_grad(): + for batch_idx, batch in enumerate(dataloader): + if verbose: + logger.info(f"Processing batch {batch_idx+1}/{total_batches}") + + mel = batch["mel"] # [batch, n_mels, n_frames] + audio_paths = batch["audio_paths"] + batch_size = mel.shape[0] + + # Process each item in batch + for i in range(batch_size): + audio_path = audio_paths[i] + mel_i = mel[i:i+1] # Keep batch dimension + + # Auto-detect language if needed + current_language = language + if not current_language: + current_language = detect_language(model, mel_i, tokenizer, resource_path) + if verbose: + logger.info(f"Detected language: {current_language}") + + # Set up decoding options + options = DecodingOptions( + task=task, + language=current_language, + beam_size=beam_size, + temperature=temperature, + without_timestamps=without_timestamps, + **decoder_options if decoder_options else {} + ) + + # Transcribe + result = transcribe( + model=model, + tokenizer=tokenizer, + mel=mel_i, + resource_path=resource_path, + verbose=verbose, + **options.__dict__ + ) + + # Add file path to result + result["audio_path"] = audio_path + results.append(result) + + if verbose: + logger.info(f"Transcribed {os.path.basename(audio_path)}: 
{result['text'][:80]}...") + + return results
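+
+
+# A minimal usage sketch for this module (the paths and checkpoint name below
+# are illustrative only, not a fixed layout of the pipeline):
+#
+#     from paddlespeech.s2t.models.whisper.fine_tune import WhisperInferenceDataset
+#
+#     model = load_whisper_model(
+#         model_size="base",
+#         checkpoint_path="./exp/whisper_fine_tune/final",
+#         resource_path="./resources")
+#     dataset = WhisperInferenceDataset(
+#         "./samples", language="en", resource_path="./resources")
+#     loader = dataset.create_dataloader(batch_size=4)
+#     results = batch_transcribe_with_progress(
+#         model, dataset, loader, resource_path="./resources",
+#         language="en", verbose=True)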