remove resource and fix code style.

pull/2640/head
zxcd 3 years ago
parent ae98cc8bbf
commit 28b5778eed

@@ -15,13 +15,14 @@
 import os.path
 import sys

+import distutils
+import numpy as np
 import paddle
 import soundfile
+from yacs.config import CfgNode

-from paddlespeech.s2t.models.whisper import _download
 from paddlespeech.s2t.models.whisper import ModelDimensions
 from paddlespeech.s2t.models.whisper import transcribe
-from paddlespeech.s2t.models.whisper import utils
 from paddlespeech.s2t.models.whisper import Whisper
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.log import Log
@@ -29,17 +30,43 @@ from paddlespeech.s2t.utils.log import Log
 logger = Log(__name__).getlog()


-def load_model(model_file):
-    logger.info("download and loading the model file......")
-    download_root = os.getenv(
-        "XDG_CACHE_HOME",
-        os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
-    model_file = _download(args.model_file, download_root, in_memory=False)
-    model_dict = paddle.load(model_file)
-    dims = ModelDimensions(**model_dict["dims"])
-    model = Whisper(dims)
-    model.load_dict(model_dict)
-    return model
+class WhisperInfer():
+    def __init__(self, config, args):
+        self.args = args
+        self.config = config
+        self.audio_file = args.audio_file
+
+        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
+        config.pop("ngpu")
+
+        #load_model
+        model_dict = paddle.load(self.config.model_file)
+        config.pop("model_file")
+        dims = ModelDimensions(**model_dict["dims"])
+        self.model = Whisper(dims)
+        self.model.load_dict(model_dict)
+
+    def run(self):
+        check(args.audio_file)
+        with paddle.no_grad():
+            temperature = config.pop("temperature")
+            temperature_increment_on_fallback = config.pop(
+                "temperature_increment_on_fallback")
+            if temperature_increment_on_fallback is not None:
+                temperature = tuple(
+                    np.arange(temperature, 1.0 + 1e-6,
+                              temperature_increment_on_fallback))
+            else:
+                temperature = [temperature]
+
+            result = transcribe(
+                self.model, args.audio_file, temperature=temperature, **config)
+            if args.result_file is not None:
+                with open(args.result_file, 'w') as f:
+                    f.write(str(result))
+
+            print("result", result)
+            return result


 def check(audio_file: str):
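For context on the fallback logic in run(): when temperature_increment_on_fallback is set, transcribe() receives a whole schedule of temperatures to retry decoding with, built by np.arange. A minimal sketch of that computation, assuming the upstream Whisper defaults of temperature=0.0 and an increment of 0.2 (both values are assumptions taken from upstream, not from this diff):

    import numpy as np

    temperature = 0.0   # starting temperature (assumed default)
    increment = 0.2     # temperature_increment_on_fallback (assumed default)
    # The 1e-6 epsilon makes the 1.0 endpoint inclusive despite float rounding.
    schedule = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    print(schedule)     # ~ (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), up to float rounding

Each entry is tried in order until decoding succeeds without falling back, which is why a single fixed temperature is wrapped in a one-element list in the else branch.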
@@ -49,7 +76,7 @@ def check(audio_file: str):
     logger.info("checking the audio file format......")
     try:
-        sig, sample_rate = soundfile.read(audio_file)
+        _, sample_rate = soundfile.read(audio_file)
     except Exception as e:
         logger.error(str(e))
         logger.error(
@@ -60,38 +87,33 @@ def check(audio_file: str):
     logger.info("The audio file format is right")


+def main(config, args):
+    WhisperInfer(config, args).run()
+
+
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save asr result to
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
     parser.add_argument(
         "--audio_file", type=str, help="path of the input audio file")
     parser.add_argument(
-        "--model_file",
-        default="large",
-        type=str,
-        help="path of the input model file")
-    parser.add_argument("--beam_size", type=utils.optional_int, default=5)
-    parser.add_argument("--verbose", type=utils.str2bool, default=True)
-    parser.add_argument("--device", default="gpu")
+        "--debug",
+        type=distutils.util.strtobool,
+        default=False,
+        help="for debug.")
     args = parser.parse_args()

-    check(args.audio_file)
-    available_device = paddle.get_device()
-    if args.device == "cpu" and "gpu:" in available_device:
-        warnings.warn("Performing inference on CPU when CUDA is available")
-        paddle.set_device("cpu")
-    else:
-        paddle.set_device("gpu")
-
-    model = load_model(args.model_file)
-
-    result = transcribe(
-        model,
-        args.audio_file,
-        beam_size=args.beam_size,
-        fp16=False,
-        verbose=True)
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.decode_cfg:
+        decode_confs = CfgNode(new_allowed=True)
+        decode_confs.merge_from_file(args.decode_cfg)
+        config.decode = decode_confs
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    main(config, args)
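The new __main__ block replaces the old ad-hoc flags with yacs config composition: a base config, an optional decode config mounted under config.decode, and command-line override pairs. A minimal standalone sketch of the same pattern (the keys and values here are hypothetical, chosen only to mirror the names used above):

    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)
    config.model_file = "large.model.pdparams"   # hypothetical value
    config.temperature = 0.0                     # hypothetical value

    decode_confs = CfgNode(new_allowed=True)
    decode_confs.beam_size = 5                   # hypothetical value
    config.decode = decode_confs                 # mount decode options under one key

    config.merge_from_list(["temperature", "0.2"])  # CLI-style key/value overrides
    config.freeze()                              # reject any further mutation
    print(config.temperature, config.decode.beam_size)  # 0.2 5

merge_from_list parses string values with literal_eval, which is how string overrides from argparse end up with the right types.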

@@ -2,83 +2,11 @@
 # Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
 #
 # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py)
-import hashlib
-import io
-import os
-import urllib
-import warnings
-from typing import List
-from typing import Optional
-from typing import Union
-
-import paddle
-from more_itertools import padded
-from tqdm import tqdm
-
-from paddlespeech.s2t.models.whisper.audio import log_mel_spectrogram
-from paddlespeech.s2t.models.whisper.audio import pad_or_trim
-from paddlespeech.s2t.models.whisper.decoding import decode
-from paddlespeech.s2t.models.whisper.decoding import DecodingOptions
-from paddlespeech.s2t.models.whisper.decoding import DecodingResult
-from paddlespeech.s2t.models.whisper.decoding import detect_language
-from paddlespeech.s2t.models.whisper.model import ModelDimensions
-from paddlespeech.s2t.models.whisper.model import Whisper
-from paddlespeech.s2t.models.whisper.transcribe import transcribe
-
-_MODELS = {
-    "large":
-    "https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/large.model.pdparams"
-}
-
-_MODELS_sha256 = {
-    "large": "589a2229582cc9173091f2481bba2cc8228997502ac75cbb0be6d874e8433d0f"
-}
-
-
-def _download(model_key: str, root: str, in_memory: bool) -> Union[bytes, str]:
-    os.makedirs(root, exist_ok=True)
-
-    expected_sha256 = _MODELS_sha256[model_key]
-    url = _MODELS[model_key]
-    download_target = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(
-            f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        model_bytes = open(download_target, "rb").read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
-        else:
-            warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
-
-    with urllib.request.urlopen(url) as source, open(download_target,
-                                                     "wb") as output:
-        with tqdm(
-                total=int(source.info().get("Content-Length")),
-                ncols=80,
-                unit='iB',
-                unit_scale=True,
-                unit_divisor=1024) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
-        )
-
-    return model_bytes if in_memory else download_target
-
-
-def available_models() -> List[str]:
-    """Returns the names of available models"""
-    return list(_MODELS.keys())
+from paddlespeech.s2t.models.whisper.whipser import decode
+from paddlespeech.s2t.models.whisper.whipser import DecodingOptions
+from paddlespeech.s2t.models.whisper.whipser import DecodingResult
+from paddlespeech.s2t.models.whisper.whipser import detect_language
+from paddlespeech.s2t.models.whisper.whipser import log_mel_spectrogram
+from paddlespeech.s2t.models.whisper.whipser import ModelDimensions
+from paddlespeech.s2t.models.whisper.whipser import transcribe
+from paddlespeech.s2t.models.whisper.whipser import Whisper
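This commit drops the in-package downloader: the model checkpoint is now supplied via the config's model_file instead of being fetched on demand. For reference, the removed _download helper follows a conventional download-and-verify flow; a condensed sketch of that pattern (without the tqdm progress bar, and with a hypothetical function name):

    import hashlib
    import os
    import urllib.request

    def download_verified(url: str, target: str, expected_sha256: str) -> str:
        """Download url to target and verify its SHA256 checksum."""
        if os.path.isfile(target):
            with open(target, "rb") as f:
                if hashlib.sha256(f.read()).hexdigest() == expected_sha256:
                    return target  # cached copy is intact, skip the download
        with urllib.request.urlopen(url) as source, open(target, "wb") as output:
            while True:
                buffer = source.read(8192)   # stream in 8 KiB chunks
                if not buffer:
                    break
                output.write(buffer)
        with open(target, "rb") as f:
            if hashlib.sha256(f.read()).hexdigest() != expected_sha256:
                raise RuntimeError("SHA256 mismatch after download; please retry.")
        return target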

@@ -12,7 +12,8 @@ from typing import Union
 import numpy as np
 import paddle
-from transformers import GPT2TokenizerFast
+from paddlenlp.transformers import GPTTokenizer
+#from transformers import GPT2TokenizerFast

 LANGUAGES = {
     "en": "english",
@@ -135,9 +136,9 @@ TO_LANGUAGE_CODE = {

 @dataclass(frozen=True)
 class Tokenizer:
-    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
+    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""

-    tokenizer: "GPT2TokenizerFast"
+    tokenizer: "GPTTokenizer"
     language: Optional[str]
     sot_sequence: Tuple[int]
@@ -147,6 +148,15 @@ class Tokenizer:
     def decode(self,
                token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
                **kwargs):
+        if len(token_ids) > 1:
+            ids_list = []
+            for ids in token_ids:
+                if paddle.is_tensor(ids):
+                    ids = ids.item()
+                if ids < len(self.tokenizer):
+                    ids_list.append(ids)
+            token_ids = ids_list
         return self.tokenizer.decode(token_ids, **kwargs)

     def decode_with_timestamps(self, tokens) -> str:
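The ids filtered out in decode() are those at or beyond len(self.tokenizer), i.e. Whisper's special tokens appended after the base GPT-2 vocabulary, which paddlenlp's GPTTokenizer.decode would otherwise choke on. A standalone sketch of the same filtering, where vocab_size stands in for len(self.tokenizer):

    import paddle

    def keep_text_ids(token_ids, vocab_size):
        """Drop ids outside the base vocabulary before detokenization."""
        ids_list = []
        for ids in token_ids:
            if paddle.is_tensor(ids):   # unwrap 0-d tensors to plain ints
                ids = ids.item()
            if ids < vocab_size:        # special tokens sit at/above vocab_size
                ids_list.append(ids)
        return ids_list

    print(keep_text_ids([15496, 995, 50257], vocab_size=50257))  # [15496, 995]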
@@ -269,12 +279,13 @@ class Tokenizer:
         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
         result = {
-            self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]
+            self.tokenizer.encode(" -").input_ids[0],
+            self.tokenizer.encode(" '").input_ids[0]
         }

         for symbol in symbols + list(miscellaneous):
             for tokens in [
-                    self.tokenizer.encode(symbol),
-                    self.tokenizer.encode(" " + symbol)
+                    self.tokenizer.encode(symbol).input_ids,
+                    self.tokenizer.encode(" " + symbol).input_ids
             ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])
@@ -282,7 +293,7 @@ class Tokenizer:
         return tuple(sorted(result))

     def _get_single_token_id(self, text) -> int:
-        tokens = self.tokenizer.encode(text)
+        tokens = self.tokenizer.encode(text).input_ids
         assert len(tokens) == 1, f"{text} is not encoded as a single token"
         return tokens[0]
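The recurring .input_ids additions in these hunks reflect an API difference: GPT2TokenizerFast.encode returns a plain list of ids, while (per this diff) paddlenlp's GPTTokenizer.encode returns an encoding object whose ids live in input_ids. A hedged compatibility helper, not part of the diff, that tolerates both shapes:

    def encode_ids(tokenizer, text):
        """Return token ids as a plain list regardless of tokenizer API."""
        out = tokenizer.encode(text)
        # transformers-style encode() already returns List[int];
        # paddlenlp-style returns an object/dict carrying input_ids.
        if hasattr(out, "input_ids"):
            return out.input_ids
        if isinstance(out, dict):
            return out["input_ids"]
        return out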
@@ -291,7 +302,7 @@
 def build_tokenizer(name: str="gpt2"):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     path = os.path.join(os.path.dirname(__file__), "assets", name)
-    tokenizer = GPT2TokenizerFast.from_pretrained(path)
+    tokenizer = GPTTokenizer.from_pretrained(path)

     specials = [
         "<|startoftranscript|>",

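Taken together, the tokenizer changes are mechanical: load GPTTokenizer from the bundled assets instead of GPT2TokenizerFast, and unwrap .input_ids wherever encode() is called. A hypothetical round-trip under those assumptions (assuming build_tokenizer above is in scope and the GPT-2 assets are present under assets/gpt2):

    tokenizer = build_tokenizer(name="gpt2")
    ids = tokenizer.encode("hello world").input_ids  # paddlenlp-style: ids via .input_ids
    print(tokenizer.decode(ids))                     # expect "hello world" back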
File diff suppressed because it is too large.