|
|
@ -18,8 +18,10 @@ import numpy as np
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle.inference import Config
|
|
|
|
from paddle.inference import Config
|
|
|
|
from paddle.inference import create_predictor
|
|
|
|
from paddle.inference import create_predictor
|
|
|
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
|
|
|
|
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
|
|
|
|
|
|
|
|
from deepspeech.io.collator import SpeechCollator
|
|
|
|
from deepspeech.io.dataset import ManifestDataset
|
|
|
|
from deepspeech.io.dataset import ManifestDataset
|
|
|
|
from deepspeech.models.deepspeech2 import DeepSpeech2Model
|
|
|
|
from deepspeech.models.deepspeech2 import DeepSpeech2Model
|
|
|
|
from deepspeech.training.cli import default_argument_parser
|
|
|
|
from deepspeech.training.cli import default_argument_parser
|
|
|
@ -78,26 +80,31 @@ def inference(config, args):
|
|
|
|
def start_server(config, args):
|
|
|
|
def start_server(config, args):
|
|
|
|
"""Start the ASR server"""
|
|
|
|
"""Start the ASR server"""
|
|
|
|
config.defrost()
|
|
|
|
config.defrost()
|
|
|
|
config.data.manfiest = config.data.test_manifest
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
config.data.augmentation_config = ""
|
|
|
|
|
|
|
|
config.data.keep_transcription_text = True
|
|
|
|
|
|
|
|
dataset = ManifestDataset.from_config(config)
|
|
|
|
dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
model = DeepSpeech2Model.from_pretrained(dataset, config,
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = True
|
|
|
|
|
|
|
|
config.collator.batch_size = 1
|
|
|
|
|
|
|
|
config.collator.num_workers = 0
|
|
|
|
|
|
|
|
collate_fn = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = DeepSpeech2Model.from_pretrained(test_loader, config,
|
|
|
|
args.checkpoint_path)
|
|
|
|
args.checkpoint_path)
|
|
|
|
model.eval()
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
# prepare ASR inference handler
|
|
|
|
# prepare ASR inference handler
|
|
|
|
def file_to_transcript(filename):
|
|
|
|
def file_to_transcript(filename):
|
|
|
|
feature = dataset.process_utterance(filename, "")
|
|
|
|
feature = test_loader.collate_fn.process_utterance(filename, "")
|
|
|
|
audio = np.array([feature[0]]).astype('float32') #[1, D, T]
|
|
|
|
audio = np.array([feature[0]]).astype('float32') #[1, T, D]
|
|
|
|
audio_len = feature[0].shape[1]
|
|
|
|
audio_len = feature[0].shape[0]
|
|
|
|
audio_len = np.array([audio_len]).astype('int64') # [1]
|
|
|
|
audio_len = np.array([audio_len]).astype('int64') # [1]
|
|
|
|
|
|
|
|
|
|
|
|
result_transcript = model.decode(
|
|
|
|
result_transcript = model.decode(
|
|
|
|
paddle.to_tensor(audio),
|
|
|
|
paddle.to_tensor(audio),
|
|
|
|
paddle.to_tensor(audio_len),
|
|
|
|
paddle.to_tensor(audio_len),
|
|
|
|
vocab_list=dataset.vocab_list,
|
|
|
|
vocab_list=test_loader.collate_fn.vocab_list,
|
|
|
|
decoding_method=config.decoding.decoding_method,
|
|
|
|
decoding_method=config.decoding.decoding_method,
|
|
|
|
lang_model_path=config.decoding.lang_model_path,
|
|
|
|
lang_model_path=config.decoding.lang_model_path,
|
|
|
|
beam_alpha=config.decoding.alpha,
|
|
|
|
beam_alpha=config.decoding.alpha,
|
|
|
@ -138,7 +145,7 @@ if __name__ == "__main__":
|
|
|
|
add_arg('host_ip', str,
|
|
|
|
add_arg('host_ip', str,
|
|
|
|
'localhost',
|
|
|
|
'localhost',
|
|
|
|
"Server's IP address.")
|
|
|
|
"Server's IP address.")
|
|
|
|
add_arg('host_port', int, 8086, "Server's IP port.")
|
|
|
|
add_arg('host_port', int, 8089, "Server's IP port.")
|
|
|
|
add_arg('speech_save_dir', str,
|
|
|
|
add_arg('speech_save_dir', str,
|
|
|
|
'demo_cache',
|
|
|
|
'demo_cache',
|
|
|
|
"Directory to save demo audios.")
|
|
|
|
"Directory to save demo audios.")
|
|
|
|