You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
11 KiB
241 lines
11 KiB
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import paddle
|
|
import yaml
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
from paddlespeech.s2t.utils.utility import str2bool
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
def prepare_parser():
    """Build the command-line parser for the Whisper CLI.

    The top-level parser carries four sub-commands -- ``prepare``, ``train``,
    ``infer`` and ``evaluate`` -- and stores the chosen one in ``args.command``.
    Defaults are shown in ``--help`` output via ``ArgumentDefaultsHelpFormatter``.

    Returns:
        argparse.ArgumentParser: the configured top-level parser.
    """
    whisper_sizes = ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"]
    whisper_tasks = ["transcribe", "translate"]

    def add_model_and_decoding_options(sub, language_help, timestamps_default):
        # Options shared by `infer` and `evaluate`; only the --language help
        # text and the --without_timestamps default differ between the two.
        sub.add_argument(
            "--model_size",
            type=str,
            default="base",
            choices=whisper_sizes,
            help="Model size for original Whisper")
        sub.add_argument("--language", type=str, default=None, help=language_help)
        sub.add_argument(
            "--task",
            type=str,
            default="transcribe",
            choices=whisper_tasks,
            help="Task: transcribe or translate to English")
        sub.add_argument("--temperature", type=float, default=0.0, help="Temperature for sampling")
        sub.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search")
        sub.add_argument(
            "--without_timestamps",
            type=str2bool,
            default=timestamps_default,
            help="Don't include timestamps")

    cli = argparse.ArgumentParser(
        description="Whisper CLI for fine-tuning, inference, and evaluation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    commands = cli.add_subparsers(dest="command", help="Command to run")

    # --- prepare: dataset preparation -------------------------------------
    prepare_cmd = commands.add_parser("prepare", help="Prepare data for fine-tuning")
    prepare_options = (
        ("--language", str, "en", "Language code (e.g., en, fr, es)"),
        ("--output_dir", str, "./data", "Directory to save preprocessed data"),
        ("--cache_dir", str, None, "Cache directory for HuggingFace datasets"),
        ("--val_size", float, 0.03, "Validation set size as fraction"),
        ("--test_size", float, 0.03, "Test set size as fraction"),
        ("--min_duration", float, 0.5, "Minimum audio duration in seconds"),
        ("--max_duration", float, 30.0, "Maximum audio duration in seconds"),
        ("--seed", int, 42, "Random seed"),
    )
    for flag, value_type, default, text in prepare_options:
        prepare_cmd.add_argument(flag, type=value_type, default=default, help=text)

    # --- train: fine-tuning -----------------------------------------------
    train_cmd = commands.add_parser("train", help="Fine-tune Whisper model")
    train_cmd.add_argument("--config", type=str, required=True, help="Path to configuration file")
    train_cmd.add_argument(
        "--resource_path", type=str, default="./resources", help="Path to resources directory")
    train_cmd.add_argument(
        "--device", type=str, default="gpu", choices=["cpu", "gpu", "xpu"], help="Device to use")
    train_cmd.add_argument("--seed", type=int, default=42, help="Random seed")
    train_cmd.add_argument(
        "--checkpoint_path", type=str, default=None, help="Path to checkpoint to resume from")
    train_cmd.add_argument(
        "--distributed", type=str2bool, default=False, help="Enable distributed training")

    # --- infer: transcription of a file or a directory --------------------
    infer_cmd = commands.add_parser("infer", help="Inference with Whisper model")
    infer_cmd.add_argument("--audio_file", type=str, help="Path to audio file for transcription")
    infer_cmd.add_argument("--audio_dir", type=str, help="Path to directory containing audio files")
    infer_cmd.add_argument(
        "--output_dir", type=str, default="./transcripts", help="Output directory for transcriptions")
    infer_cmd.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    infer_cmd.add_argument(
        "--resource_path",
        type=str,
        default="./resources",
        help="Path to resources directory containing original models and assets")
    infer_cmd.add_argument(
        "--use_original",
        type=str2bool,
        default=False,
        help="Use original Whisper model instead of fine-tuned")
    add_model_and_decoding_options(
        infer_cmd,
        language_help="Language code (e.g., en, fr, auto for detection)",
        timestamps_default=False)

    # --- evaluate: scoring against a manifest -----------------------------
    evaluate_cmd = commands.add_parser("evaluate", help="Evaluate Whisper model")
    evaluate_cmd.add_argument("--manifest", type=str, required=True, help="Path to manifest file")
    evaluate_cmd.add_argument(
        "--output_dir", type=str, default="./results", help="Directory to save results")
    evaluate_cmd.add_argument("--checkpoint", type=str, help="Path to model checkpoint from fine-tuning")
    evaluate_cmd.add_argument(
        "--resource_path", type=str, default="./resources", help="Path to resources directory")
    add_model_and_decoding_options(
        evaluate_cmd,
        language_help="Language code (e.g., en, fr)",
        timestamps_default=True)
    evaluate_cmd.add_argument(
        "--max_samples", type=int, default=None, help="Maximum number of samples to evaluate")

    return cli
|
|
|
|
|
|
def _run_prepare(args):
    """Run the ``prepare`` command by forwarding the parsed options to ``prepare_common_voice``."""
    from prepare_data import prepare_common_voice

    prepare_common_voice(
        language=args.language,
        output_dir=args.output_dir,
        cache_dir=args.cache_dir,
        val_size=args.val_size,
        test_size=args.test_size,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        seed=args.seed,
    )


def _run_train(args):
    """Run the ``train`` command: load the YAML config, select the device and start training."""
    import train

    if args.distributed:
        import paddle.distributed as dist

        dist.init_parallel_env()

    # Read the config as UTF-8 explicitly so parsing does not depend on the
    # platform's default text encoding.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    paddle.set_device(args.device)

    trainer = train.WhisperTrainer(config, args)
    trainer.train()


def _run_infer(args):
    """Run the ``infer`` command on a single audio file or on a directory.

    Exits with status 1 when neither ``--audio_file`` nor ``--audio_dir`` is
    given, when the option required by the selected model source is missing,
    or when single-file transcription returns a falsy result.
    When both inputs are given, ``--audio_file`` takes precedence.
    """
    import json

    from infer import batch_transcribe, load_whisper_model, transcribe_file

    if not args.audio_file and not args.audio_dir:
        logger.error("Either --audio_file or --audio_dir must be specified")
        sys.exit(1)

    if args.use_original:
        # Original Whisper weights are looked up through the resources directory.
        if not args.resource_path:
            logger.error("--resource_path must be specified when using original model")
            sys.exit(1)

        model = load_whisper_model(
            model_size=args.model_size,
            resource_path=args.resource_path,
        )
    else:
        # Fine-tuned weights come from an explicit checkpoint.
        if not args.checkpoint:
            logger.error("--checkpoint must be specified when not using original model")
            sys.exit(1)

        model = load_whisper_model(
            model_size=args.model_size,
            checkpoint_path=args.checkpoint,
            resource_path=args.resource_path,
        )

    transcribe_kwargs = {
        "language": args.language,
        "task": args.task,
        "temperature": args.temperature,
        "beam_size": args.beam_size,
        "without_timestamps": args.without_timestamps,
    }

    if args.audio_file:
        result = transcribe_file(model, args.audio_file, args.resource_path, **transcribe_kwargs)
        if not result:
            # Previously this case was skipped silently and the process
            # exited with status 0; surface it so callers can detect it.
            logger.error(f"No transcription result was returned for {args.audio_file}")
            sys.exit(1)

        print("-" * 40)
        print("Transcription:")
        print(result["text"])
        print("-" * 40)

        if args.output_dir:
            output_dir = Path(args.output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            # Both output files are named after the audio file, minus its extension.
            file_name = os.path.splitext(os.path.basename(args.audio_file))[0]
            with open(output_dir / f"{file_name}.txt", "w", encoding="utf-8") as f:
                f.write(result["text"].strip())

            with open(output_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            print(f"Results saved to {output_dir}")
    elif args.audio_dir:
        batch_transcribe(
            model,
            args.audio_dir,
            args.output_dir,
            args.resource_path,
            **transcribe_kwargs
        )


def _run_evaluate(args):
    """Run the ``evaluate`` command: load a model and evaluate it on the manifest."""
    from evaluate import evaluate_manifest, load_model

    model = load_model(
        model_size=args.model_size,
        checkpoint_path=args.checkpoint,
        resource_path=args.resource_path,
    )

    evaluate_manifest(
        model=model,
        manifest_path=args.manifest,
        resource_path=args.resource_path,
        language=args.language,
        task=args.task,
        temperature=args.temperature,
        beam_size=args.beam_size,
        without_timestamps=args.without_timestamps,
        output_dir=args.output_dir,
        max_samples=args.max_samples,
    )


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Prints the help text and exits with status 1 when no sub-command is given.
    The modules backing each command are imported inside the command helpers,
    so only the selected command's module is loaded.
    """
    parser = prepare_parser()
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    handlers = {
        "prepare": _run_prepare,
        "train": _run_train,
        "infer": _run_infer,
        "evaluate": _run_evaluate,
    }
    # argparse only accepts registered sub-command names; an unknown name
    # (not expected) is ignored, as it was in the original if/elif chain.
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|