revise from_pretrained function

pull/684/head
Haoxin Ma 4 years ago
parent 3652b87f33
commit d55e6b5a0a

@ -29,6 +29,9 @@ from deepspeech.utils.socket_server import warm_up_test
from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def init_predictor(args): def init_predictor(args):
if args.model_dir is not None: if args.model_dir is not None:
@ -83,7 +86,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config) dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config, config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()

@ -28,6 +28,9 @@ from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def start_server(config, args): def start_server(config, args):
"""Start the ASR server""" """Start the ASR server"""
config.defrost() config.defrost()
@ -36,7 +39,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config) dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config, config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()

@ -47,7 +47,7 @@ def tune(config, args):
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True)) collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config, model = DeepSpeech2Model.from_pretrained(valid_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()

@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def export(self): def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained( infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path) self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval() infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(

@ -506,7 +506,7 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec. List[paddle.static.InputSpec]: input spec.
""" """
from deepspeech.models.u2 import U2InferModel from deepspeech.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(), self.config.model.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size

@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer):
cutoff_top_n, num_processes) cutoff_top_n, num_processes)
@classmethod @classmethod
def from_pretrained(cls, dataset, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model. """Build a DeepSpeech2Model model from a pretrained model.
Parameters Parameters
---------- ----------
dataset: paddle.io.Dataset dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode config: yacs.config.CfgNode
model configs model configs
@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer):
DeepSpeech2Model DeepSpeech2Model
The model built from pretrained result. The model built from pretrained result.
""" """
model = cls(feat_size=dataset.feature_size, model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataset.vocab_size, dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers, num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,

@ -876,11 +876,11 @@ class U2Model(U2BaseModel):
return model return model
@classmethod @classmethod
def from_pretrained(cls, dataset, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model. """Build a DeepSpeech2Model model from a pretrained model.
Args: Args:
dataset (paddle.io.Dataset): not used. dataloader (paddle.io.DataLoader): not used.
config (yacs.config.CfgNode): model configs config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
@ -888,8 +888,8 @@ class U2Model(U2BaseModel):
DeepSpeech2Model: The model built from pretrained result. DeepSpeech2Model: The model built from pretrained result.
""" """
config.defrost() config.defrost()
config.input_dim = dataset.feature_size config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataset.vocab_size config.output_dim = dataloader.collate_fn.vocab_size
config.freeze() config.freeze()
model = cls.from_config(config) model = cls.from_config(config)

Loading…
Cancel
Save