fix the bug of 'import path error' for ds2

pull/791/head
huangyuxin 3 years ago
parent 86d08f994b
commit 8d5062702b

@ -83,8 +83,8 @@
"from deepspeech.frontend.utility import read_manifest\n", "from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n", "from deepspeech.utils.utility import add_arguments, print_arguments\n",
"\n", "\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2Model\n", "from deepspeech.models.ds2 import DeepSpeech2Model\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n", "from deepspeech.models.ds2 import DeepSpeech2InferModel\n",
"from deepspeech.io.dataset import ManifestDataset\n", "from deepspeech.io.dataset import ManifestDataset\n",
"\n", "\n",
"\n", "\n",
@ -669,4 +669,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 2
} }

@ -23,7 +23,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.utils.socket_server import AsrTCPServer from deepspeech.utils.socket_server import AsrTCPServer

@ -21,7 +21,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.utils.socket_server import AsrTCPServer from deepspeech.utils.socket_server import AsrTCPServer

@ -21,7 +21,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import add_arguments

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4 from deepspeech.modules.subsampling import Conv2dSubsampling4

@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer): class CRNNEncoder(nn.Layer):
@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size rnn_input_size = i_size
else: else:
rnn_input_size = layernorm_size rnn_input_size = layernorm_size
if use_gru == True: if use_gru is True:
self.rnn.append( self.rnn.append(
nn.GRU( nn.GRU(
input_size=rnn_input_size, input_size=rnn_input_size,
@ -113,7 +113,7 @@ class CRNNEncoder(nn.Layer):
if init_state_h_box is not None: if init_state_h_box is not None:
init_state_list = None init_state_list = None
if self.use_gru == True: if self.use_gru is True:
init_state_h_list = paddle.split( init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0) init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list init_state_list = init_state_h_list
@ -139,7 +139,7 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x) x = self.fc_layers_list[i](x)
x = F.relu(x) x = F.relu(x)
if self.use_gru == True: if self.use_gru is True:
final_chunk_state_h_box = paddle.concat( final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0) final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)

@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

Loading…
Cancel
Save