Merge pull request #930 from Jackwaterveg/join_ctc

Join ctc
pull/931/head
Hui Zhang 4 years ago committed by GitHub
commit ed19e243de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
@ -78,12 +79,18 @@ def recog_v2(args):
preprocess_args={"train": False}, )
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
lm = lm_class(len(char_list), lm_args)
torch_load(args.rnnlm, lm)
lm_path = args.rnnlm
lm = TransformerLM(
n_vocab=5002,
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
model_dict = paddle.load(lm_path)
lm.set_state_dict(model_dict)
lm.eval()
else:
lm = None

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import List
from typing import Tuple
@ -150,7 +151,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axis=-1).squeeze(0)
logp = F.log_softmax(h).squeeze(0)
return logp, cache
# batch beam search API (see BatchScorerInterface)
@ -193,7 +194,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axi=-1)
logp = F.log_softmax(h)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)]
@ -231,14 +232,14 @@ if __name__ == "__main__":
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
state = (None, None, 0)
state = None
output, state = tlm.score(input2, state, None)
input3 = np.array([10])
input3 = np.array([5, 10])
input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None)
input4 = np.array([0])
input4 = np.array([5, 10, 0])
input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None)
print("output", output)

@ -399,7 +399,8 @@ class TransformerEncoder(BaseEncoder):
xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0)
else:
xs = self.embed(xs)
xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)

Loading…
Cancel
Save