Merge pull request #1046 from Jackwaterveg/refactor

[Rename]rename the config.model.feat_size and the config.model.vocab.size
pull/1049/head
Hui Zhang 3 years ago committed by GitHub
commit 85f7f674d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,4 +24,4 @@
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.725063021977743 | 0.047417 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.725063021977743 | 0.053922 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.725063021977743 | 0.053180 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.725063021977743 | 0.041026 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.725063021977743 | 0.041026 |

@ -110,8 +110,8 @@ class DeepSpeech2Tester_hub():
def setup_model(self):
config = self.config.clone()
with UpdateConfig(config):
config.model.feat_size = self.collate_fn_test.feature_size
config.model.dict_size = self.collate_fn_test.vocab_size
config.model.input_dim = self.collate_fn_test.feature_size
config.model.output_dim = self.collate_fn_test.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)

@ -154,11 +154,11 @@ class DeepSpeech2Trainer(Trainer):
config = self.config.clone()
with UpdateConfig(config):
if self.train:
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.model.input_dim = self.train_loader.collate_fn.feature_size
config.model.output_dim = self.train_loader.collate_fn.vocab_size
else:
config.model.feat_size = self.test_loader.collate_fn.feature_size
config.model.dict_size = self.test_loader.collate_fn.vocab_size
config.model.input_dim = self.test_loader.collate_fn.feature_size
config.model.output_dim = self.test_loader.collate_fn.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)

@ -249,8 +249,8 @@ class DeepSpeech2Model(nn.Layer):
The model built from config.
"""
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
feat_size=config.input_dim,
dict_size=config.output_dim,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,

@ -381,8 +381,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
The model built from config.
"""
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
feat_size=config.input_dim,
dict_size=config.output_dim,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,

Loading…
Cancel
Save