diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index f3125e04..0ec36b5d 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -29,6 +29,9 @@ from deepspeech.utils.socket_server import warm_up_test from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import print_arguments +from paddle.io import DataLoader +from deepspeech.io.collator import SpeechCollator + def init_predictor(args): if args.model_dir is not None: @@ -83,7 +86,12 @@ def start_server(config, args): config.data.keep_transcription_text = True dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.batch_size=1 + config.collator.num_workers=0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index b2ff37e0..40ba4c72 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -28,6 +28,9 @@ from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import print_arguments +from paddle.io import DataLoader +from deepspeech.io.collator import SpeechCollator + def start_server(config, args): """Start the ASR server""" config.defrost() @@ -36,7 +39,12 @@ def start_server(config, args): config.data.keep_transcription_text = True dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.batch_size=1 + config.collator.num_workers=0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 02e329a1..f10dc27c 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -47,7 +47,7 @@ def tune(config, args): drop_last=False, collate_fn=SpeechCollator(keep_transcription_text=True)) - model = DeepSpeech2Model.from_pretrained(dev_dataset, config, + model = DeepSpeech2Model.from_pretrained(valid_loader, config, args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index deb8752b..209e8b02 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def export(self): infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader.dataset, self.config, self.args.checkpoint_path) + self.test_loader, self.config, self.args.checkpoint_path) infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size static_model = paddle.jit.to_static( diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 05551875..308569cd 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -506,7 +506,7 @@ class U2Tester(U2Trainer): List[paddle.static.InputSpec]: input spec. """ from deepspeech.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, + infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.model.clone(), self.args.checkpoint_path) feat_dim = self.test_loader.collate_fn.feature_size diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 0ff5514d..d2c03a18 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer): cutoff_top_n, num_processes) @classmethod - def from_pretrained(cls, dataset, config, checkpoint_path): + def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. Parameters ---------- - dataset: paddle.io.Dataset + dataloader: paddle.io.DataLoader config: yacs.config.CfgNode model configs @@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer): DeepSpeech2Model The model built from pretrained result. """ - model = cls(feat_size=dataset.feature_size, - dict_size=dataset.vocab_size, + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 238e2d35..23ae3423 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -876,11 +876,11 @@ class U2Model(U2BaseModel): return model @classmethod - def from_pretrained(cls, dataset, config, checkpoint_path): + def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. Args: - dataset (paddle.io.Dataset): not used. + dataloader (paddle.io.DataLoader): not used. config (yacs.config.CfgNode): model configs checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name @@ -888,8 +888,8 @@ class U2Model(U2BaseModel): DeepSpeech2Model: The model built from pretrained result. """ config.defrost() - config.input_dim = dataset.feature_size - config.output_dim = dataset.vocab_size + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size config.freeze() model = cls.from_config(config)