You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/examples/commonvoice/whisper/whisper_cli.py

241 lines
11 KiB

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from pathlib import Path
import paddle
import yaml
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import str2bool
# Module-level logger; main() uses it to report invalid argument combinations.
logger = Log(__name__).getlog()
def prepare_parser():
    """Build the command-line parser for the Whisper CLI.

    The parser exposes four sub-commands -- ``prepare``, ``train``, ``infer``
    and ``evaluate`` -- and stores the chosen one in ``args.command``.

    Returns:
        argparse.ArgumentParser: The fully configured top-level parser.
    """
    root = argparse.ArgumentParser(
        description="Whisper CLI for fine-tuning, inference, and evaluation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    commands = root.add_subparsers(dest="command", help="Command to run")

    # Choice lists that "infer" and "evaluate" declare identically.
    whisper_sizes = ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"]
    whisper_tasks = ["transcribe", "translate"]

    # ---- prepare: dataset preparation -------------------------------------
    prepare_cmd = commands.add_parser("prepare", help="Prepare data for fine-tuning")
    add = prepare_cmd.add_argument
    add("--language", type=str, default="en", help="Language code (e.g., en, fr, es)")
    add("--output_dir", type=str, default="./data", help="Directory to save preprocessed data")
    add("--cache_dir", type=str, default=None, help="Cache directory for HuggingFace datasets")
    add("--val_size", type=float, default=0.03, help="Validation set size as fraction")
    add("--test_size", type=float, default=0.03, help="Test set size as fraction")
    add("--min_duration", type=float, default=0.5, help="Minimum audio duration in seconds")
    add("--max_duration", type=float, default=30.0, help="Maximum audio duration in seconds")
    add("--seed", type=int, default=42, help="Random seed")

    # ---- train: fine-tuning -----------------------------------------------
    train_cmd = commands.add_parser("train", help="Fine-tune Whisper model")
    add = train_cmd.add_argument
    add("--config", type=str, required=True, help="Path to configuration file")
    add("--resource_path", type=str, default="./resources", help="Path to resources directory")
    add("--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
    add("--seed", type=int, default=42, help="Random seed")
    add("--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
    add("--distributed", type=str2bool, default=False, help="Enable distributed training")

    # ---- infer: transcription ---------------------------------------------
    infer_cmd = commands.add_parser("infer", help="Inference with Whisper model")
    add = infer_cmd.add_argument
    add("--audio_file", type=str, help="Path to audio file for transcription")
    add("--audio_dir", type=str, help="Path to directory containing audio files")
    add("--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
    add("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    add("--resource_path", type=str, default="./resources",
        help="Path to resources directory containing original models and assets")
    # Model selection
    add("--use_original", type=str2bool, default=False,
        help="Use original Whisper model instead of fine-tuned")
    add("--model_size", type=str, default="base", choices=whisper_sizes,
        help="Model size for original Whisper")
    # Decoding behaviour
    add("--language", type=str, default=None, help="Language code (e.g., en, fr, auto for detection)")
    add("--task", type=str, default="transcribe", choices=whisper_tasks,
        help="Task: transcribe or translate to English")
    add("--temperature", type=float, default=0.0, help="Temperature for sampling")
    add("--beam_size", type=int, default=5, help="Beam size for beam search")
    add("--without_timestamps", type=str2bool, default=False, help="Don't include timestamps")

    # ---- evaluate: scoring against a manifest -----------------------------
    evaluate_cmd = commands.add_parser("evaluate", help="Evaluate Whisper model")
    add = evaluate_cmd.add_argument
    add("--manifest", type=str, required=True, help="Path to manifest file")
    add("--output_dir", type=str, default="./results", help="Directory to save results")
    add("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    add("--resource_path", type=str, default="./resources", help="Path to resources directory")
    # Model selection
    add("--model_size", type=str, default="base", choices=whisper_sizes,
        help="Model size for original Whisper")
    # Decoding behaviour (note: timestamps are off by default here, unlike "infer")
    add("--language", type=str, default=None, help="Language code (e.g., en, fr)")
    add("--task", type=str, default="transcribe", choices=whisper_tasks,
        help="Task: transcribe or translate to English")
    add("--temperature", type=float, default=0.0, help="Temperature for sampling")
    add("--beam_size", type=int, default=5, help="Beam size for beam search")
    add("--without_timestamps", type=str2bool, default=True, help="Don't include timestamps")
    add("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")

    return root
def _run_prepare(args):
    """Run the ``prepare`` command by forwarding the parsed options to ``prepare_common_voice``."""
    # Imported here so the module is only loaded when this command is selected.
    from prepare_data import prepare_common_voice

    prepare_common_voice(
        language=args.language,
        output_dir=args.output_dir,
        cache_dir=args.cache_dir,
        val_size=args.val_size,
        test_size=args.test_size,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        seed=args.seed,
    )


def _run_train(args):
    """Run the ``train`` command: load the YAML config and start ``WhisperTrainer``."""
    import train

    if args.distributed:
        import paddle.distributed as dist
        dist.init_parallel_env()

    # Read the config as UTF-8 explicitly so parsing does not depend on the
    # platform's default locale encoding.
    with open(args.config, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    paddle.set_device(args.device)

    trainer = train.WhisperTrainer(config, args)
    trainer.train()


def _save_transcription(result, audio_file, output_dir):
    """Write ``result`` to ``<output_dir>/<audio stem>.txt`` and ``.json``, creating the directory."""
    import json

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    file_name = os.path.splitext(os.path.basename(audio_file))[0]

    # Plain-text transcript.
    with open(target_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
        f.write(result["text"].strip())

    # Full result structure as returned by transcribe_file.
    with open(target_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"Results saved to {target_dir}")


def _run_infer(args):
    """Run the ``infer`` command on a single file or on a directory of files.

    Exits with status 1 when the arguments are inconsistent (no input given,
    or no checkpoint while ``--use_original`` is false) or when single-file
    transcription yields no result.
    """
    from infer import load_whisper_model, transcribe_file, batch_transcribe

    if not args.audio_file and not args.audio_dir:
        logger.error("Either --audio_file or --audio_dir must be specified")
        sys.exit(1)
    if args.audio_file and args.audio_dir:
        # The single-file path wins below; say so instead of silently dropping the directory.
        logger.warning("Both --audio_file and --audio_dir were given; --audio_dir is ignored")

    if args.use_original:
        # Original Whisper model: only the resources directory is needed.
        if not args.resource_path:
            logger.error("--resource_path must be specified when using original model")
            sys.exit(1)
        model = load_whisper_model(
            model_size=args.model_size,
            resource_path=args.resource_path,
        )
    else:
        # Fine-tuned model: a checkpoint is mandatory.
        if not args.checkpoint:
            logger.error("--checkpoint must be specified when not using original model")
            sys.exit(1)
        model = load_whisper_model(
            model_size=args.model_size,
            checkpoint_path=args.checkpoint,
            resource_path=args.resource_path,
        )

    transcribe_kwargs = {
        "language": args.language,
        "task": args.task,
        "temperature": args.temperature,
        "beam_size": args.beam_size,
        "without_timestamps": args.without_timestamps,
    }

    if args.audio_file:
        result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
        if not result:
            # A falsy result means there is nothing to print or save; report it
            # through the exit status rather than finishing as if it succeeded.
            logger.error("Transcription produced no result for the given --audio_file")
            sys.exit(1)

        print("-" * 40)
        print("Transcription:")
        print(result["text"])
        print("-" * 40)

        if args.output_dir:
            _save_transcription(result, args.audio_file, args.output_dir)
    else:
        # Validation above guarantees --audio_dir is set on this path.
        batch_transcribe(
            model,
            args.audio_dir,
            args.output_dir,
            args.resource_path,
            **transcribe_kwargs
        )


def _run_evaluate(args):
    """Run the ``evaluate`` command: load a model and score it on the manifest."""
    from evaluate import load_model, evaluate_manifest

    model = load_model(
        model_size=args.model_size,
        checkpoint_path=args.checkpoint,
        resource_path=args.resource_path,
    )

    evaluate_manifest(
        model=model,
        manifest_path=args.manifest,
        resource_path=args.resource_path,
        language=args.language,
        task=args.task,
        temperature=args.temperature,
        beam_size=args.beam_size,
        without_timestamps=args.without_timestamps,
        output_dir=args.output_dir,
        max_samples=args.max_samples,
    )


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Prints the help text and exits with status 1 when no sub-command is given.
    """
    parser = prepare_parser()
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # argparse only accepts the sub-commands registered in prepare_parser(),
    # so every non-empty args.command has an entry here.
    handlers = {
        "prepare": _run_prepare,
        "train": _run_train,
        "infer": _run_infer,
        "evaluate": _run_evaluate,
    }
    handlers[args.command](args)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()